#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# cert_utils.py  -  Utilities for certificate management
#
# Copyright (C) 2015 Jan Jockusch <jan.jockusch@perfact.de>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#


from .generic import safe_syscall, cleanup_string
from .generic import split_unescape, string_escape
from .generic import to_string
from .file import tmp_read, tmp_readstr, tmp_write

import tempfile
import time
import string
import os
import sys
import shutil
import six

# Python 2/3 compatibility
if six.PY2:
    # On Python 2, TemporaryDirectory is taken from the package's own
    # compat_tempfile module instead of the standard library.
    from .compat_tempfile import TemporaryDirectory
else:
    from tempfile import TemporaryDirectory

if not six.PY2:
    # Python 3 has no separate "long" type; make the name an alias of int.
    long = int


# Generic webserver certificate management


def webserver_valid_path(path):
    '''Decide whether "path" lies in a directory this module may touch.

    A path containing a ".." component is always rejected. Otherwise
    the path is accepted only if it starts with one of the known
    webserver SSL directory prefixes.

    Returns True for an allowed path and False for everything else.
    '''
    allowed_prefixes = (
        '/etc/haproxy/ssl/',
        '/etc/apache2/ssl.crt/',
    )
    has_parent_ref = '..' in path.split('/')
    # Anything that is not explicitly allowed is denied.
    return (not has_parent_ref) and path.startswith(allowed_prefixes)


def webserver_cert_gen(subj, keyfile, certfile):
    '''Generate a private key and self-signed certificate.

    Runs "openssl req -x509" to create a new, unencrypted ("-nodes")
    4096 bit RSA key at "keyfile" and a certificate for the subject
    "subj" at "certfile", requested with a validity of 73000 days.

    Both paths must pass webserver_valid_path(); otherwise an
    AssertionError is raised. Returns None.
    '''
    # Raise explicitly instead of using "assert": the path check guards
    # where files get written and must survive "python -O", which strips
    # assert statements.
    for path in (keyfile, certfile):
        if not webserver_valid_path(path):
            raise AssertionError('Illegal path: %r' % (path,))
    safe_syscall([
        'openssl', 'req', '-utf8', '-newkey', 'rsa:4096',
        '-keyout', keyfile,
        '-x509', '-days', '73000',
        '-out', certfile,
        '-nodes',
        '-subj', to_string(subj)], raisemode=True)
    return


def webserver_cert_csr(subj, keyfile, csrfile, alt_dns=None):
    '''Generate a certificate signing request for a given
    subject and key. Store it into the given csrfile path.

    subj: subject string; its CN (if any) is added to the alternative
          DNS names unless already listed.
    keyfile: path of the existing private key used for the request.
    csrfile: path the CSR is written to.
    alt_dns: optional list of additional DNS names. The list passed in
             is not modified.

    A temporary "csr.conf" is written next to csrfile and removed again
    afterwards, even if creating the request fails.

    Both paths must pass webserver_valid_path(); otherwise an
    AssertionError is raised.

    Returns the content of the generated CSR file.
    '''
    # Raise explicitly instead of using "assert" so the path check is
    # not stripped under "python -O".
    for path in (keyfile, csrfile):
        if not webserver_valid_path(path):
            raise AssertionError('Illegal path: %r' % (path,))

    # Create section with alt_names, including the main CN. Work on a
    # copy so the caller's list is left untouched.
    alt_dns = list(alt_dns or [])
    common_name = subj_to_dict(subj).get('CN')
    if common_name and common_name not in alt_dns:
        alt_dns.append(common_name)
    alt_names = cert_altnames(alt_dns=alt_dns)

    confpath = os.path.join(os.path.dirname(csrfile), 'csr.conf')
    with open(confpath, 'w') as f:
        f.write(ca_server_conf_template % dict(alt_names=alt_names))

    try:
        safe_syscall([
            'openssl', 'req', '-new', '-utf8',
            '-key', keyfile,
            '-subj', to_string(subj),
            '-out', csrfile,
            '-config', confpath,
        ], raisemode=True)
        with open(csrfile, 'r') as f:
            csrdata = f.read()
    finally:
        # Do not leave the temporary configuration behind on errors.
        os.remove(confpath)
    return csrdata


def webserver_cert_check(certfile, keyfile):
    '''Test whether a certificate and a private key belong together.

    Asks openssl for the modulus of the certificate ("openssl x509")
    and of the key ("openssl rsa") and compares both outputs after
    stripping surrounding whitespace. Equal moduli are taken as
    sufficient proof that the certificate and the key match.

    Both paths must pass webserver_valid_path(); otherwise an
    AssertionError is raised.

    Returns True if the two moduli are equal, False otherwise.
    '''
    # Raise explicitly instead of using "assert" so the path check is
    # not stripped under "python -O".
    for path in (keyfile, certfile):
        if not webserver_valid_path(path):
            raise AssertionError('Illegal path: %r' % (path,))
    _retcode, cert_output = safe_syscall([
        'openssl', 'x509', '-noout', '-modulus',
        '-in', certfile], raisemode=True)
    _retcode, key_output = safe_syscall([
        'openssl', 'rsa', '-noout', '-modulus',
        '-in', keyfile], raisemode=True)
    return key_output.strip() == cert_output.strip()


# HAproxy certificate management

# Directory holding the HAproxy SSL files, followed by the names of the
# individual files kept inside that directory.
haproxy_ssl_dir = '/etc/haproxy/ssl'
haproxy_ssl_cert = 'cert.pem'  # certificate
haproxy_ssl_key = 'key.pem'  # private key
haproxy_ssl_csr = 'csr.pem'  # certificate signing request
haproxy_ssl_certkey = 'certkey.pem'  # certificate and key concatenated
haproxy_ssl_cafile = 'cacert.pem'  # uploaded CA certificate data


def haproxy_restart():
    '''Restart the HAproxy service through systemd.

    Invokes "sudo /bin/systemctl restart haproxy". For this to work,
    "sudoers" has to contain:

        zope ALL=NOPASSWD:/bin/systemctl restart haproxy
    '''
    restart_cmd = ['sudo', '/bin/systemctl', 'restart', 'haproxy']
    safe_syscall(restart_cmd, raisemode=True)


def haproxy_cert_gen(subj):
    '''Generate a self-signed certificate and place it in the haproxy
    directory.

    A new key and certificate for the subject "subj" are created with
    webserver_cert_gen(). Afterwards both are concatenated (certificate
    first, then key) into the combined certkey file. Returns None.
    '''
    certfile = haproxy_ssl_dir + '/' + haproxy_ssl_cert
    keyfile = haproxy_ssl_dir + '/' + haproxy_ssl_key
    certkeyfile = haproxy_ssl_dir + '/' + haproxy_ssl_certkey
    webserver_cert_gen(subj, keyfile, certfile)

    # Concatenate the files into one. Context managers make sure every
    # handle is closed (and the output flushed) right away.
    with open(certfile, 'r') as fh:
        certdata = fh.read()
    with open(keyfile, 'r') as fh:
        keydata = fh.read()
    with open(certkeyfile, 'w') as fh:
        fh.write(certdata + keydata)
    return


def haproxy_cert_csr(subj, alt_dns=None):
    '''Create a CSR for the HAproxy key inside the haproxy directory.

    Delegates to webserver_cert_csr() using the HAproxy key and CSR
    paths and returns whatever that call returns.
    '''
    paths = dict(
        keyfile=haproxy_ssl_dir + '/' + haproxy_ssl_key,
        csrfile=haproxy_ssl_dir + '/' + haproxy_ssl_csr,
    )
    return webserver_cert_csr(subj=subj, alt_dns=alt_dns, **paths)


def haproxy_cert_import(certdata, keydata=None, chaindata=None):
    '''Import the files:
    certdata: representing the certificate,
    keydata: representing the private-key (optional); if empty, the
             key currently stored in the haproxy directory is reused,
    chaindata: representing the chainfile of concatenated
               CA-certificates (optional)

    The current key, certificate and combined file are first copied to
    timestamped backup files. The new data is written to ".new" files,
    checked with webserver_cert_check() and only then copied over the
    live files.

    Raises AssertionError if one of the inputs is too long or if the
    key and the certificate do not match. Returns None.
    '''
    keyfile = haproxy_ssl_dir + '/' + haproxy_ssl_key
    certfile = haproxy_ssl_dir + '/' + haproxy_ssl_cert
    certkeyfile = haproxy_ssl_dir + '/' + haproxy_ssl_certkey
    live_files = (keyfile, certfile, certkeyfile)

    # input cleanup
    certdata = to_string(certdata)
    keydata = to_string(keydata)
    chaindata = to_string(chaindata)

    # Prevent attackers from flooding the system. Raised explicitly
    # (not via "assert") so the limit also holds under "python -O".
    MAX_CERTLEN = 16384
    for label, value in (('certdata', certdata),
                         ('keydata', keydata or ''),
                         ('chaindata', chaindata or '')):
        if not len(value) < MAX_CERTLEN:
            raise AssertionError('%s exceeds the length limit.' % label)

    backup_suffix = '.backup_{}'.format(time.strftime('%Y-%m-%dT%H:%M:%S'))
    newfile_suffix = '.new'

    # Make a backup. NOTE: All cert files are text, so we stick to 'r' and 'w'
    for fname in live_files:
        with open(fname, 'r') as src:
            data = src.read()
        with open(fname + backup_suffix, 'w') as dst:
            dst.write(data)

    if not keydata:
        # No new key supplied: keep using the one already installed.
        with open(keyfile, 'r') as src:
            keydata = src.read()

    # Write new certificates
    new_contents = (
        (certfile, certdata),
        (keyfile, keydata),
        (certkeyfile, certdata + keydata + (chaindata or '')),
    )
    for fname, data in new_contents:
        with open(fname + newfile_suffix, 'w') as dst:
            dst.write(data)

    check_ok = webserver_cert_check(
        certfile=certfile + newfile_suffix,
        keyfile=keyfile + newfile_suffix,
    )
    if not check_ok:
        # Explicit raise: this check must never be optimized away.
        raise AssertionError("Moduli of key and certificate do not match.")

    # Move the files into place (by copying the content; the ".new"
    # files are left where they are).
    for fname in live_files:
        with open(fname + newfile_suffix, 'r') as src:
            data = src.read()
        with open(fname, 'w') as dst:
            dst.write(data)
    return


def haproxy_cafile_upload(cacertdata):
    '''Copy the CA certificate data into the SSL directory of the local
    HAproxy server.

    The parameter "cacertdata" must be passed as PEM ASCII text.

    The existing CA file is copied to a timestamped backup file before
    it is overwritten. Raises AssertionError if the data is too long.
    Returns None.
    '''
    cacertfile = haproxy_ssl_dir + '/' + haproxy_ssl_cafile

    # input cleanup
    cacertdata = to_string(cacertdata)

    # Prevent potential attackers from flooding the file system. Raised
    # explicitly (not via "assert") so it also holds under "python -O".
    MAX_CERTLEN = 1000000
    if not len(cacertdata) < MAX_CERTLEN:
        raise AssertionError('cacertdata exceeds the length limit.')

    backup_suffix = '.backup_{}'.format(time.strftime('%Y-%m-%dT%H:%M:%S'))

    # Make a backup. NOTE: All cert files are text, so we stick to 'r' and 'w'
    with open(cacertfile, 'r') as src:
        data = src.read()
    with open(cacertfile + backup_suffix, 'w') as dst:
        dst.write(data)

    # Write new file
    with open(cacertfile, 'w') as fh:
        fh.write(cacertdata)
    return


# Apache certificate management

# Directory holding the Apache SSL files, followed by the names of the
# individual files kept inside that directory.
apache_ssl_dir = '/etc/apache2/ssl.crt'
apache_ssl_key = 'server.key'  # private key
apache_ssl_crt = 'server.crt'  # certificate
apache_ssl_chain = 'cachain.crt'  # CA chain (or a link to the certificate)
apache_ssl_csr = 'server.csr'  # certificate signing request


def apache_cert_backup(failsafe=False):
    '''Copy key, certificate and chain file to backup files.

    Each backup is stored next to the original under the same name
    plus a suffix: "_failsafe" if "failsafe" is set, a timestamp
    otherwise.

    Returns the suffix used, which can later be handed to the restore
    function.
    '''
    suffix = '_failsafe'
    if not failsafe:
        suffix = time.strftime('_%y-%m-%d_%H:%M:%S')
    for name in (apache_ssl_key, apache_ssl_crt, apache_ssl_chain):
        source = os.path.join(apache_ssl_dir, name)
        shutil.copyfile(source, source + suffix)
    return suffix


def apache_cert_link_chainfile():
    '''Replace the chain file in the apache directory by a symbolic
    link pointing to the server certificate.

    The existing chain file is removed with "rm" first, then the link
    is created with "ln -s".
    '''
    chain_path = apache_ssl_dir + '/' + apache_ssl_chain
    cert_path = apache_ssl_dir + '/' + apache_ssl_crt
    safe_syscall(['rm', chain_path], raisemode=True)
    safe_syscall(['ln', '-s', cert_path, chain_path], raisemode=True)


def apache_restart_server():
    '''Restart the apache server by calling
    "sudo /usr/sbin/apache2ctl restart".
    '''
    restart_cmd = ['sudo', '/usr/sbin/apache2ctl', 'restart']
    safe_syscall(restart_cmd, raisemode=True)


def apache_cert_restore(suffix):
    '''Restore certificate files from a backup.

    suffix: the backup suffix, as returned by apache_cert_backup(). It
        is passed through cleanup_string() with digits, ASCII letters
        and "-_.:" as the valid characters before it is used in a path.

    Key, certificate and chain file are each copied back from
    "<name><suffix>" to "<name>" using "cp". Returns None.
    '''
    # string.ascii_letters exists on Python 2 and 3; the previously used
    # string.letters was removed in Python 3.
    suffix = cleanup_string(
        suffix,
        valid_chars=string.digits + string.ascii_letters + '-_.:')

    for filename in [apache_ssl_key, apache_ssl_crt, apache_ssl_chain]:
        safe_syscall([
            'cp',
            apache_ssl_dir+'/'+filename+suffix,
            apache_ssl_dir+'/'+filename], raisemode=True)
    return


def apache_cert_gen(subj):
    '''Create a fresh private key plus a self-signed certificate for
    apache, link the chain file to the new certificate and restart the
    server. The validity is whatever webserver_cert_gen() requests (a
    very long time, about 200 years).
    '''
    key_path = apache_ssl_dir + '/' + apache_ssl_key
    cert_path = apache_ssl_dir + '/' + apache_ssl_crt
    webserver_cert_gen(subj, keyfile=key_path, certfile=cert_path)
    apache_cert_link_chainfile()
    apache_restart_server()


def apache_cert_csr(subj):
    '''Generate a certificate signing request for the given subject
    (CommonName etc.) using the apache server key.

    The CSR is created inside a temporary exchange directory, which is
    removed again before returning.

    Returns the CSR as read back by tmp_readstr().
    '''
    # NOTE(review): webserver_cert_csr() only accepts csrfile paths that
    # pass webserver_valid_path(); a path inside a temporary directory
    # will normally be rejected there - confirm the intended location.
    with TemporaryDirectory() as exch_dir:
        csrpath = exch_dir + '/' + apache_ssl_csr
        webserver_cert_csr(
            subj,
            keyfile=apache_ssl_dir+'/'+apache_ssl_key,
            csrfile=csrpath,
        )
        csrdata = tmp_readstr(csrpath)
    return csrdata


def apache_cert_import(certfile, keyfile=None, chainfile=None):
    '''Import the files:
    "server.crt" representing the certificate,
    "server.key" representing the private-key (optional),
    "cachain.crt" representing the chainfile of concatenated
    CA-certificates (optional)

    Despite their names, the three parameters carry the file contents,
    not paths. Empty values are skipped, i.e. the installed file stays.

    A backup is made first. If a certificate but no chain is given, the
    chain file becomes a link to the certificate. If key and
    certificate do not match afterwards, the backup is restored and an
    AssertionError is raised; otherwise the server is restarted.
    '''
    # Make a backup
    suffix = apache_cert_backup()

    for filename, data in [
            (apache_ssl_key, keyfile),
            (apache_ssl_crt, certfile),
            (apache_ssl_chain, chainfile)]:
        if not data:
            continue
        path = apache_ssl_dir+'/'+filename
        if filename == apache_ssl_chain and os.path.islink(path):
            # The chain file may be a symlink to the certificate (see
            # apache_cert_link_chainfile). Writing through that link
            # would overwrite the certificate itself, so drop it first.
            os.remove(path)
        with open(path, 'w') as fh:
            fh.write(data)

    if certfile and not chainfile:
        apache_cert_link_chainfile()

    check_ok = webserver_cert_check(
        certfile=apache_ssl_dir+'/'+apache_ssl_crt,
        keyfile=apache_ssl_dir+'/'+apache_ssl_key,
    )
    if not check_ok:
        apache_cert_restore(suffix)
        # Raised explicitly: "assert False" would be stripped under
        # "python -O" and the failure would go unnoticed.
        raise AssertionError("Hashes of key and certificate do not match.")

    apache_restart_server()
    return


# Zope CA management

# Base directory of all CAs. Each CA lives in a sub-directory named after
# its integer ID; ca_path_fmt builds that path via "ca_path_fmt % id".
ca_path_base = "/home/zope/CA"
ca_path_fmt = ca_path_base + "/%d"
# OpenSSL lockfile. Every action taken inside the CA directory must be
# serialized. Parallel actions are not supported. The lockfile makes sure
# to only let one process work inside the CA directory at a time.
# This is only the file name; it is joined with the CA directory where used.
openssl_lockfile = 'openssl-running.lock'

# OpenSSL configuration of the CA itself. The single "%s" placeholder is
# filled with the CA directory path (see ca_initialize).
ca_conf_template = """
[ ca ]
default_ca = ca_default

[ ca_default ]
dir = %s
certs = $dir/certs
crl_dir = $dir/crl
database = $dir/index.txt
unique_subject = no
new_certs_dir = $dir/certs
certificate = $dir/certs/cacert.pem
serial = $dir/serial
crl = $dir/crl/crl.pem
private_key = $dir/private/cakey.pem
name_opt = ca_default
cert_opt = ca_default
x509_extensions = ca_ext
copy_extensions = copyall
default_days = 1826
default_crl_days = 182
default_md = sha512
policy = ca_policy

[ ca_policy ]
countryName             = optional
stateOrProvinceName     = optional
organizationName        = optional
organizationalUnitName  = optional
commonName              = supplied
emailAddress            = optional

[ ca_ext ]
subjectKeyIdentifier=hash
authorityKeyIdentifier=keyid:always,issuer:always
subjectAltName=email:copy
issuerAltName=issuer:copy
keyUsage=critical,cRLSign,keyCertSign
nsCertType=sslCA, emailCA
basicConstraints=critical,CA:true
nsComment=PerFact Zope CA

[ req ]
distinguished_name = dn
req_extensions = ca_ext

[ dn ]
"""

# Request configuration for server certificates. "%(alt_names)s" is
# replaced by the output of cert_altnames() (possibly an empty string).
ca_server_conf_template = """
[ req ]
distinguished_name      = dn
req_extensions          = req_ext

[ req_ext ]
keyUsage = keyEncipherment, digitalSignature
extendedKeyUsage = serverAuth
basicConstraints = CA:FALSE
nsCertType = server
%(alt_names)s

[ dn ]
"""

# Request configuration for user (client / e-mail) certificates.
# "%(alt_names)s" is replaced by the output of cert_altnames().
ca_user_conf_template = """
[ req ]
distinguished_name      = dn
req_extensions          = req_ext

[ req_ext ]
keyUsage = nonRepudiation, digitalSignature, keyEncipherment, keyAgreement
extendedKeyUsage = clientAuth, emailProtection
basicConstraints = CA:FALSE
nsCertType = client, email
%(alt_names)s

[ dn ]
"""

# Request configuration selected by request_type 'ovpnclient' in
# cert_makecsr(). Used as is; it contains no placeholders.
ca_ovpnclient_conf_template = """
[ req ]
distinguished_name      = dn
req_extensions          = req_ext

[ req_ext ]
keyUsage = digitalSignature
extendedKeyUsage = clientAuth
basicConstraints = CA:FALSE
nsCertType = client

[ dn ]
"""

# Request configuration selected by request_type 'sign' in
# cert_makecsr(). Used as is; it contains no placeholders.
ca_sign_conf_template = """
[ req ]
distinguished_name      = dn
req_extensions          = req_ext

[ req_ext ]
keyUsage = digitalSignature
basicConstraints = CA:FALSE

[ dn ]
"""


def ca_initialize(subj, cadays=8000, passphrase=None, appca_id=1,
                  ca_path_fmt=ca_path_fmt):
    '''Create new configuration files for OpenSSL
    using the data from 'subj'. Create a new private key and
    certificate for the CA and generate all necessary files and
    directories.

    subj: subject string of the CA certificate.
    cadays: requested validity of the CA certificate in days.
    passphrase: optional passphrase protecting the CA key.
    appca_id: integer ID of the CA; selects the CA directory.
    ca_path_fmt: format string turning the ID into the CA directory.

    An already existing CA directory is moved aside to a timestamped
    name before the new one is created.

    Returns the newly generated private key of the CA.
    '''

    ca_path = ca_path_fmt % int(appca_id)

    # Make a backup of the current infrastructure. raisemode=False as
    # in the original code: there may be no directory to move yet.
    suffix = time.strftime('_%y%m%d_%H%M%S')
    safe_syscall(['mv', ca_path, ca_path+suffix], raisemode=False)

    # Build ca file structure
    safe_syscall(['mkdir', '-p', ca_path], raisemode=True)
    safe_syscall(['chmod', '700', ca_path], raisemode=True)

    safe_syscall(['touch', ca_path+'/index.txt'], raisemode=True)
    # write the serial number 01 into the serial file
    with open(ca_path + '/serial', 'w') as fh:
        fh.write('01\n')

    # Create the configuration
    with open(ca_path + '/openssl.conf', 'w') as fh:
        fh.write(ca_conf_template % ca_path)

    # Create all required directories
    dirs = ['certs', 'private', 'crl', 'certified_req', 'req', 'revoked']
    for dirname in dirs:
        safe_syscall(['mkdir', '-p', '-m', '700', ca_path+'/'+dirname],
                     raisemode=True)

    # Generate the private key.
    key = cert_makekey(passphrase=passphrase)
    with open(ca_path+'/private/cakey.pem', 'w') as fh:
        fh.write(key)

    opts = []
    with TemporaryDirectory() as tmpdir:
        if passphrase:
            # Hand the passphrase to openssl through a file in the
            # temporary directory instead of the command line.
            passname = tmpdir+'/pass'
            tmp_write(passname, passphrase)
            opts.extend(['-passin', 'file:'+passname])

        # Use openssl to generate the cacert.pem file
        safe_syscall(['openssl', 'req',
                      '-config', ca_path+'/openssl.conf',
                      '-new', '-x509', '-utf8',
                      '-days', to_string(cadays),
                      '-subj', to_string(subj),
                      '-extensions', 'ca_ext'] +
                     opts +
                     ['-key', ca_path+'/private/cakey.pem',
                      '-out', ca_path+'/certs/cacert.pem'], raisemode=True)
    return key


def ca_destroy(appca_id):
    '''Delete the directory of the CA with the given ID, including
    everything inside it.

    Returns True if the directory existed and was removed, False if
    there was nothing to remove.
    '''
    target = ca_path_fmt % int(appca_id)

    # Nothing there: report that no deletion took place.
    if not os.path.exists(target):
        return False

    safe_syscall(['rm', '-rf', target], raisemode=True)
    return True


def ca_purge(excluded_ids=None):
    '''Purge all CA directories except for a list of excluded ones.
    Used on template rollout to clean up the CA folder.

    excluded_ids: optional iterable of CA IDs whose directories must be
        kept. If omitted or empty, every directory found is removed.

    Returns True if the directory listing reported exit code 0.
    '''
    # retrieve all directories directly below the CA folder
    cmd = ['find', ca_path_base,
           '-maxdepth', '1',
           '-mindepth', '1',
           '-type', 'd']
    retcode, output = safe_syscall(cmd, raisemode=True)

    # format excluded paths
    excluded_paths = set(
        ca_path_fmt % int(ca_id) for ca_id in (excluded_ids or ()))

    # iterate results and remove; empty lines (e.g. after the final
    # newline) are skipped
    for path in output.split('\n'):
        if path and path not in excluded_paths:
            safe_syscall(['rm', '-rf', path], raisemode=True)
    return retcode == 0


def _check_openssl_timestamp(value, message):
    '''Raise AssertionError(message) unless "value" consists of 12 or
    14 digits followed by a literal "Z".'''
    well_formed = (len(value) in (13, 15) and
                   value[:-1].isdigit() and
                   value[-1] == 'Z')
    if not well_formed:
        raise AssertionError(message)


def ca_signreq(csr, days=None, expire=None,
               passphrase=None, appca_id=1, start=None,
               flock_timeout=5, ca_path_fmt=ca_path_fmt):
    '''Use the CA to sign the given request.

    Either "days" or "expire" may be supplied. Days is a simple
    integer, while "expire" has to be a openssl compatible string of
    the format '%y%m%d%H%M%SZ'.
    If the certificate should not be valid starting from 'now()',
    a "start" parameter with openssl compatible string of the format
    '%y%m%d%H%M%SZ' can be given.

    "passphrase" unlocks the CA key if given. "appca_id" and
    "ca_path_fmt" select the CA directory. The openssl call is wrapped
    in "flock" on the CA lockfile, waiting up to "flock_timeout"
    seconds for the lock.

    Raises AssertionError if "start" or "expire" is malformed.

    Returns the certificate.

    '''
    ca_path = ca_path_fmt % int(appca_id)
    lockfile = os.path.join(ca_path, openssl_lockfile)
    lock_cmds = ['flock', '--wait', str(flock_timeout), lockfile]

    opts = []
    if days:
        opts.extend(['-days', '%d' % days])

    # The format checks raise explicitly (instead of using "assert") so
    # they are still performed under "python -O".
    if start:
        _check_openssl_timestamp(start, "Illegal format in start.")
        opts.extend(['-startdate', start])

    if expire:
        _check_openssl_timestamp(expire, "Illegal format in expire.")
        opts.extend(['-enddate', expire])

    with TemporaryDirectory() as tmpdir:
        csrfile = tmpdir + '/in.csr'
        certfile = tmpdir + '/cert.pem'
        passfile = tmpdir + '/pass'

        if passphrase:
            tmp_write(passfile, passphrase)
            opts.extend(['-passin', 'file:'+passfile])

        tmp_write(csrfile, csr)
        safe_syscall(
            lock_cmds +
            ['openssl', 'ca',
             '-config', ca_path+'/openssl.conf',
             '-in', csrfile,
             '-batch'] +
            opts +
            ['-out', certfile], raisemode=True)

        cert = tmp_readstr(certfile)
    return cert


def ca_makep12(key, cert, in_passphrase=None, out_passphrase=None,
               appca_id=1, certpbe='AES-256-CBC', keypbe='AES-256-CBC',
               include_cacert=True):
    '''Bundle a private key and (optionally) its certificate into a
    PKCS#12 file, e.g. for import into a browser or another client.

    key: the private key.
    cert: the signed certificate; if empty, "-nocerts" is passed to
        openssl and no certificate is included.
    in_passphrase: passphrase of the private key, if it is protected.
    out_passphrase: passphrase protecting the resulting p12 file.
    appca_id: ID of the CA whose certificate may be included.
    certpbe, keypbe: algorithms passed as "-certpbe" / "-keypbe".
    include_cacert: whether to add the CA certificate to the bundle.

    Returns the PKCS#12 data as read back by tmp_read().

    '''
    ca_path = ca_path_fmt % int(appca_id)

    with TemporaryDirectory() as workdir:
        key_path = os.path.join(workdir, 'key.pem')
        cert_path = os.path.join(workdir, 'cert.pem')
        p12_path = os.path.join(workdir, 'out.p12')
        passin_path = os.path.join(workdir, 'passin')
        passout_path = os.path.join(workdir, 'passout')

        extra = []
        if in_passphrase:
            tmp_write(passin_path, in_passphrase)
            extra += ['-passin', 'file:' + passin_path]
        if out_passphrase:
            tmp_write(passout_path, out_passphrase)
            extra += ['-keypbe', keypbe, '-passout', 'file:' + passout_path]
        tmp_write(key_path, key)

        if not cert:
            extra.append('-nocerts')
        else:
            tmp_write(cert_path, cert)
            # Options shared by both certificate variants come first.
            extra += ['-certpbe', certpbe, '-in', cert_path]
            if include_cacert:
                extra += ['-certfile', ca_path + '/certs/cacert.pem',
                          '-name', 'PerFact Certificate',
                          '-caname', 'PerFact CA Certificate']
            else:
                extra += ['-name', 'PerFact Certificate']

        export_cmd = ['openssl', 'pkcs12', '-export',
                      '-inkey', key_path,
                      '-out', p12_path] + extra
        safe_syscall(export_cmd, raisemode=True)

        p12data = tmp_read(p12_path)
    return p12data


def ca_extend(days=1826, passphrase=None, appca_id=1):
    '''Re-issue the CA certificate with a validity of "days" days
    (1826, i.e. about 5 years, by default).

    The existing CA certificate is re-signed with the CA key into a
    new file, which then replaces the old certificate. "passphrase"
    unlocks the CA key if given.
    '''
    ca_path = ca_path_fmt % int(appca_id)
    current_cert = ca_path + '/certs/cacert.pem'
    renewed_cert = ca_path + '/certs/cacert_new.pem'

    with TemporaryDirectory() as workdir:
        pass_path = os.path.join(workdir, 'pass')

        pass_opts = []
        if passphrase:
            tmp_write(pass_path, passphrase)
            pass_opts = ['-passin', 'file:' + pass_path]

        sign_cmd = [
            'openssl', 'x509',
            '-days', str(int(days)),
            '-in', current_cert,
            '-out', renewed_cert,
            '-signkey', ca_path + '/private/cakey.pem',
        ] + pass_opts
        safe_syscall(sign_cmd, raisemode=True)

        # Only after successful signing is the old certificate replaced.
        safe_syscall(['mv', renewed_cert, current_cert], raisemode=True)


def ca_revokecert(cert, passphrase=None, appca_id=1):
    '''Revoke "cert" in the CA selected by "appca_id".

    The openssl call is wrapped in "flock" on the CA lockfile with a
    wait time of 5 seconds. "passphrase" unlocks the CA key if given.
    '''
    ca_path = ca_path_fmt % int(appca_id)
    lock_prefix = ['flock', '--wait', '5',
                   os.path.join(ca_path, openssl_lockfile)]

    with TemporaryDirectory() as workdir:
        pass_path = os.path.join(workdir, 'pass')
        cert_path = os.path.join(workdir, 'cert.pem')

        tmp_write(cert_path, cert)
        pass_opts = []
        if passphrase:
            tmp_write(pass_path, passphrase)
            pass_opts = ['-passin', 'file:' + pass_path]

        revoke_cmd = ['openssl', 'ca',
                      '-revoke', cert_path,
                      '-config', ca_path + '/openssl.conf']
        safe_syscall(lock_prefix + revoke_cmd + pass_opts, raisemode=True)


def ca_getcert(appca_id=1):
    '''Read and return the content of the CA certificate file of the
    CA selected by "appca_id".'''
    cert_path = (ca_path_fmt % int(appca_id)) + '/certs/cacert.pem'
    with open(cert_path, 'r') as handle:
        return handle.read()


# Client-side utilities

def cert_makekey(passphrase=None, rsabits=4096):
    '''Create a new RSA private key with "openssl genrsa".

    passphrase: if given, the key is encrypted with AES-256 using this
        passphrase.
    rsabits: key size in bits.

    Returns the private key as read back by tmp_readstr().'''
    with TemporaryDirectory() as workdir:
        key_path = os.path.join(workdir, 'key.pem')
        pass_path = os.path.join(workdir, 'pass')

        command = ['openssl', 'genrsa']
        if passphrase:
            tmp_write(pass_path, passphrase)
            command += ['-aes256', '-passout', 'file:' + pass_path]
        # The key size has to be the last argument.
        command += ['-out', key_path, '%d' % rsabits]

        safe_syscall(command, raisemode=True)
        return tmp_readstr(key_path)


def cert_changepassphrase(key, passphrase=None, new_passphrase=None):
    '''Write the given key out again under a different passphrase.

    passphrase: current passphrase of the key, if it has one.
    new_passphrase: passphrase for the new key; if given, the key is
        encrypted with AES-256. Without it, no encryption option is
        passed to openssl.

    Returns the re-written key as read back by tmp_readstr().'''
    with TemporaryDirectory() as workdir:
        old_key_path = os.path.join(workdir, 'key.pem')
        new_key_path = os.path.join(workdir, 'keynew.pem')
        old_pass_path = os.path.join(workdir, 'pass')
        new_pass_path = os.path.join(workdir, 'passnew')

        tmp_write(old_key_path, key)
        command = ['openssl', 'rsa']
        if passphrase:
            tmp_write(old_pass_path, passphrase)
            command.extend(['-passin', 'file:' + old_pass_path])
        if new_passphrase:
            tmp_write(new_pass_path, new_passphrase)
            command.extend(['-aes256', '-passout', 'file:' + new_pass_path])
        command.extend(['-in', old_key_path, '-out', new_key_path])

        safe_syscall(command, raisemode=True)
        return tmp_readstr(new_key_path)


def cert_getpub(key, passphrase=None):
    '''Derive the public key from a private key via
    "openssl rsa -pubout".

    passphrase: passphrase of the private key, if it is protected.

    Returns the public key as read back by tmp_readstr().'''
    with TemporaryDirectory() as workdir:
        key_path = os.path.join(workdir, 'key.pem')
        pub_path = os.path.join(workdir, 'key.pub')
        pass_path = os.path.join(workdir, 'pass')

        command = ['openssl', 'rsa', '-in', key_path]
        if passphrase:
            tmp_write(pass_path, passphrase)
            command += ['-passin', 'file:' + pass_path]
        tmp_write(key_path, key)
        command += ['-pubout', '-out', pub_path]

        safe_syscall(command, raisemode=True)
        return tmp_readstr(pub_path)


def cert_altnames(alt_dns=None, alt_emails=None):
    """
    Build the alternative-names addition for a request config.
    :alt_dns: List of DNS names
    :alt_emails: List of emails for user certificates
    With no names at all, an empty string is returned. Otherwise the
    result is the "subjectAltName" reference to the [alt_names] section
    followed by that section itself (e-mail entries first, then DNS
    entries, each numbered from 1). It is meant to be placed at the end
    of the [req_ext] section.
    """
    entries = []
    for prefix, names in (('email', alt_emails), ('DNS', alt_dns)):
        for number, name in enumerate(names or [], 1):
            entries.append('%s.%d = %s' % (prefix, number, name))

    if entries:
        header = "subjectAltName=@alt_names\n\n[alt_names]\n"
        return header + '\n'.join(entries)
    return ''


def cert_makecsr(key, subj, request_type='server', passphrase=None,
                 alt_emails=None, alt_dns=None):
    '''Create a certificate signing request (CSR) for a key and a
    subject string using "openssl req -new".

    key: the key the request is created for.
    subj: the subject string of the request.
    request_type: one of 'server', 'user', 'sign' or 'ovpnclient';
        selects the request configuration template.
    passphrase: passphrase of the key, if it is protected.
    alt_emails, alt_dns: alternative names; only the 'server' and
        'user' templates include them.

    Raises AssertionError for an unknown request type.

    Returns the CSR as read back by tmp_readstr().
    '''
    substitutions = {
        'alt_names': cert_altnames(alt_dns=alt_dns, alt_emails=alt_emails),
    }
    templates = {
        'server': ca_server_conf_template % substitutions,
        'user': ca_user_conf_template % substitutions,
        'sign': ca_sign_conf_template,
        'ovpnclient': ca_ovpnclient_conf_template,
    }
    conf = templates.get(request_type)
    assert conf, "Illegal request type."

    with TemporaryDirectory() as workdir:
        key_path = os.path.join(workdir, 'key.pem')
        csr_path = os.path.join(workdir, 'out.csr')
        pass_path = os.path.join(workdir, 'pass')
        conf_path = os.path.join(workdir, 'openssl.conf')

        tmp_write(key_path, key)
        tmp_write(conf_path, conf)

        command = ['openssl', 'req', '-new', '-utf8',
                   '-config', conf_path,
                   '-key', key_path]
        if passphrase:
            tmp_write(pass_path, passphrase)
            command += ['-passin', 'file:' + pass_path]
        command += ['-subj', to_string(subj), '-out', csr_path]

        safe_syscall(command, raisemode=True)
        return tmp_readstr(csr_path)


def cert_totext(data, fmt='x509', passphrase=None):
    '''Produce the textual dump openssl prints for the given data.

    data: the object to inspect.
    fmt: one of 'req', 'x509', 'pkcs12', 'rsa' or 'rsapub'. 'rsapub'
        is handled as 'rsa' with the additional "-pubin" option.
    passphrase: passphrase needed to read the data, if any.

    Raises AssertionError for an unknown format.

    Returns the output of the openssl call.
    '''
    with TemporaryDirectory() as workdir:
        data_path = os.path.join(workdir, 'file.pem')
        tmp_write(data_path, data)

        extra = []
        pass_path = os.path.join(workdir, 'pass')
        if passphrase:
            tmp_write(pass_path, passphrase)
            extra += ['-passin', 'file:' + pass_path]

        assert fmt in ['req', 'x509', 'pkcs12', 'rsa', 'rsapub'], \
            "Illegal format requested."
        # Public keys are read by the "rsa" sub-command with "-pubin".
        if fmt == 'rsapub':
            extra.append('-pubin')
            fmt = 'rsa'

        # pkcs12 has no "-text" option; "-info" is used there instead.
        extra.append('-info' if fmt == 'pkcs12' else '-text')

        command = ['openssl', fmt, '-noout', '-in', data_path] + extra
        _retval, output = safe_syscall(command, raisemode=True)
    return output


def cert_fingerprint(data, hash='sha1', passphrase=None):
    '''Compute the fingerprint of a certificate with "openssl x509".

    hash: one of 'sha1', 'md5' or 'sha256'.
    passphrase: accepted but not used by this function.

    Raises AssertionError for an unknown hash type.

    Returns the part of the openssl output after the first "=", with
    all colons removed, i.e. only the hexadecimal digits.
    '''
    with TemporaryDirectory() as workdir:
        cert_path = os.path.join(workdir, 'file.pem')
        tmp_write(cert_path, data)

        assert hash in ['sha1', 'md5', 'sha256'], \
            "Illegal hash type requested."
        command = ['openssl', 'x509', '-noout',
                   '-in', cert_path,
                   '-fingerprint', '-' + hash]
        _retval, output = safe_syscall(command, raisemode=True)

    # Keep only what follows "=" and drop the colon separators.
    digits = output.strip().split('=', 1)[-1]
    return digits.replace(':', '')


def cert_encryptdata(data, cert, form='DER'):
    '''Encrypt data for the holder of the given x509 certificate and
    return the encrypted data.

    The input and the certificate are handed to tmp_write() for files
    inside a temporary directory, "openssl smime -encrypt" with
    AES-256-CBC is run on them, and the result is read back with
    tmp_read(). The temporary directory only lives for the duration of
    the call.

    "data" is passed as binary.  A sanity limit of 1.0GiB is imposed.
    "form" may be one of "SMIME", "DER", "PEM".

    Raises AssertionError for an unknown form or oversized input.
    '''
    with TemporaryDirectory() as workdir:
        cert_path = os.path.join(workdir, 'cert.pem')

        assert form in ['DER', 'PEM', 'SMIME'], "Illegal format used."
        assert len(data) < (1024**3), "Input length limit exceeded."
        plain_path = os.path.join(workdir, 'in.dat')
        cipher_path = os.path.join(workdir, 'out.dat')

        tmp_write(plain_path, data)
        tmp_write(cert_path, cert)

        command = ['openssl', 'smime',
                   '-encrypt', '-binary', '-aes-256-cbc',
                   '-in', plain_path, '-out', cipher_path,
                   '-outform', form, cert_path]
        safe_syscall(command, raisemode=True)
        return tmp_read(cipher_path)


def cert_decryptdata(data, key, passphrase=None, form='DER'):
    '''Decrypt data with the given private key and return the result.
    If a passphrase is passed, it is used to unlock the key.

    Input, key and passphrase are handed to tmp_write() for files
    inside a temporary directory, "openssl smime -decrypt" is run on
    them, and the output is read back with tmp_read(). The result is
    returned as a whole, so it must fit into RAM.

    The length of "data" is checked with len(); a sanity limit of
    1.0GiB is imposed.
    "form" may be one of "SMIME", "DER", "PEM".

    Raises AssertionError for an unknown form or oversized input.
    '''
    with TemporaryDirectory() as workdir:
        key_path = os.path.join(workdir, 'key.pem')
        pass_path = os.path.join(workdir, 'pass')

        assert form in ['DER', 'PEM', 'SMIME'], "Illegal format used."
        assert len(data) < (1024**3), "Input length limit exceeded."
        cipher_path = os.path.join(workdir, 'in.dat')
        plain_path = os.path.join(workdir, 'out.dat')
        tmp_write(cipher_path, data)

        command = ['openssl', 'smime',
                   '-decrypt', '-binary',
                   '-inform', form,
                   '-inkey', key_path]
        if passphrase:
            tmp_write(pass_path, passphrase)
            command += ['-passin', 'file:' + pass_path]
        tmp_write(key_path, key)
        command += ['-in', cipher_path, '-out', plain_path]

        safe_syscall(command, raisemode=True)
        return tmp_read(plain_path)


# Helpers

def subj_to_dict(subj):
    '''Parse a subject string into a dictionary of its fields.

    The first character of "subj" (the leading "/") is dropped. The
    remainder is split at "/" into "KEY=value" parts, each of which is
    split at "=" into a dictionary entry. Backslash is the escape
    character for both splits.

    >>> subj_to_dict('/C=DE/ST=NRW/L=Herford/O=perfact::ema/'
    ...     'CN=ema-devel') == {
    ...         'CN': 'ema-devel',
    ...         'C': 'DE',
    ...         'L': 'Herford',
    ...         'O': 'perfact::ema',
    ...         'ST': 'NRW'
    ...     }
    True
    '''
    # Escapes are kept in this first pass so that the second split
    # still sees them.
    components = split_unescape(subj[1:], '/', '\\', unescape=False)
    return dict(
        split_unescape(component, '=', '\\', unescape=True)
        for component in components
    )


def dict_to_subj(subj_dict, tokens=['C', 'ST', 'L', 'O', 'OU', 'CN']):
    '''Build a subject string from a dictionary.

    For every entry of "tokens", in order, a "/KEY=value" part is
    emitted; keys missing from the dictionary get an empty value. Keys
    and values are passed through string_escape() for "=" and "/".

    >>> dict_to_subj({'CN': 'ema-devel-2014', 'C': 'DE', 'L': 'Herford'})
    '/C=DE/ST=/L=Herford/O=/OU=/CN=ema-devel-2014'
    '''
    parts = []
    for token in tokens:
        name = string_escape(token, '=/', '\\')
        value = string_escape(subj_dict.get(token, ''), '=/', '\\')
        parts.append('/' + name + '=' + value)
    return ''.join(parts)


def ca_directory_check_integrity(appca_id=1, ca_path_fmt=ca_path_fmt,
                                 ca_path=None, req_files=None):
    '''
    Tell whether a CA directory contains all files it is required to have.

    :param appca_id: The ID of the CA to check. Only used to build the
        path if ca_path is not given
    :type appca_id: int, optional
    :param ca_path_fmt: The format string the CA path is built from
        (ca_path_fmt % int(appca_id))
    :type ca_path_fmt: str, optional
    :param ca_path: The full path of the CA directory. If provided,
        appca_id and ca_path_fmt are ignored
    :type ca_path: str, optional
    :param req_files: Names of the files which must exist inside the
        CA directory. Defaults to 'index.txt' and 'serial', plus
        'index.txt.attr' if 'index.txt.attr.new' is present
    :type req_files: tuple, list, optional
    :raises AssertionError: If the CA path is not an existing directory
    :return: 'True' if every required file is listed in the CA
        directory and 'False' otherwise
    :rtype: bool
    '''
    directory = ca_path
    if directory is None:
        directory = ca_path_fmt % int(appca_id)

    if not os.path.isdir(directory):
        raise AssertionError(
            'CA directory: {} could not be found'.format(directory)
        )

    present = os.listdir(directory)

    required = req_files
    if required is None:
        required = ['index.txt', 'serial']
        # index.txt.attr gets special treatment: it is only required once
        # at least one certificate has been signed. As long as there is
        # no ".new" file for it, a missing index.txt.attr does not mean
        # the directory is inconsistent.
        if 'index.txt.attr.new' in present:
            required.append('index.txt.attr')

    # The directory is intact if no required file is absent.
    missing = [name for name in required if name not in present]
    return not missing


def ca_directory_repair(appca_id=1, ca_path_fmt=ca_path_fmt,
                        ca_path=None, req_files=None,
                        flock_timeout=5):
    '''
    Repair a CA directory which was left in an inconsistent state.

    For every required file which is missing while a file of the same
    name with the extension '.new' exists, the '.new' file is renamed to
    the required name. The rename is run through "flock" on the lockfile
    inside the CA directory.

    :param appca_id: The ID of the CA to repair. Used to build the path
        if ca_path is not given, and reported in the error message
    :type appca_id: int, optional
    :param ca_path_fmt: The format string the CA path is built from
        (ca_path_fmt % int(appca_id))
    :type ca_path_fmt: str, optional
    :param ca_path: The full path of the CA directory. If provided,
        The ca_path_fmt argument will be ignored
    :type ca_path: str, optional
    :param req_files: Names of the files which must exist inside the
        CA directory. Defaults to 'index.txt', 'index.txt.attr' and
        'serial'
    :type req_files: tuple, list, optional
    :param flock_timeout: Value handed to the "--wait" option of flock,
        i.e. the time to wait for the lock in seconds
    :type flock_timeout: int, optional
    :raises AssertionError: If the CA path is not an existing directory
        or if a rename ended with a return code other than 0 or 1
    :return: 'True' if at least one rename was issued and ended with
        return code 0 or 1, 'False' if nothing had to be renamed
    :rtype: bool
    '''
    directory = ca_path
    if directory is None:
        directory = ca_path_fmt % int(appca_id)

    if not os.path.isdir(directory):
        raise AssertionError(
            'CA directory: {} could not be found'.format(directory)
        )

    # Command prefix which makes the following command run under the
    # lock of the CA directory. '-E 10' selects the exit code flock uses
    # when the lock could not be obtained within the waiting time.
    lock_path = os.path.join(directory, openssl_lockfile)
    flock_prefix = ['flock', '-E', '10', '--wait', str(flock_timeout)]
    flock_prefix.append(lock_path)

    names = req_files
    if names is None:
        names = ['index.txt', 'index.txt.attr', 'serial']
    # Build all full paths up front, before any file is touched.
    targets = [os.path.join(directory, name) for name in names]

    did_repair = False

    for target in targets:
        # A '.new' file can be left behind if the process writing it
        # stops unexpectedly before renaming it to its final name.
        leftover = target + '.new'

        # Only act if the leftover exists while the required file itself
        # is absent; otherwise there is nothing to restore.
        if not os.path.isfile(leftover) or os.path.isfile(target):
            continue

        retcode, _ = safe_syscall(
            flock_prefix + ['mv', leftover, target],
            raisemode=False
        )

        # Return code 1 is tolerated on purpose: it is taken to mean that
        # the leftover was already moved by someone else in the meantime.
        # NOTE(review): the message reports appca_id even if the
        # directory was given explicitly through ca_path.
        if retcode not in (0, 1):
            raise AssertionError(
                'Error while repairing the CA directory: {}'.format(
                    appca_id
                )
            )

        did_repair = True

    return did_repair


def test_all():
    '''Basic testing sequence. Generate new CA with and without
    passphrase.
    Sign all kinds of requests.

    Pack certs in p12 files and see if they are correct.

    Extend the CA certificate.

    Test for problems with locking of CA database files which led to bugs
    in the past.

    Note: Test results are written to a /tmp directory and deleted
    afterwards. The cleanup is the last statement of the test, so the
    directory is left behind if any step before it fails.
    '''
    # GENERAL settings
    ca_tests_base = '/tmp/python-modules_cert_TESTS'
    # Point the module-wide CA path template to the test directory.
    # The previous value is not restored when the test is done.
    # NOTE(review): functions which took ca_path_fmt as a default argument
    # value captured the old value at definition time - confirm that the
    # CA helpers called below read the global at call time.
    global ca_path_fmt
    ca_path_fmt = ca_tests_base + '/%d'

    # Build two CAs with the same subject: the first one without, the
    # second one with a passphrase. All signing below uses the second.
    ca_pass = 'ldskjf;f93579&/k3n4j'
    ca_subj = ("/C=DE/ST=NRW/L=Herford/O=perfact::ema"
               "/CN=PerFact-Certificate-Authority")
    appca_id1 = 1234
    appca_id2 = 5678
    ca_initialize(ca_subj, appca_id=appca_id1)
    ca_initialize(ca_subj, passphrase=ca_pass, appca_id=appca_id2)

    cert_pass = 'lkwk45j29d80vn2'
    cert_pass2 = '@Nc+(;!o=El5~[XZ'
    p12_pass = 'nfdiiw94587gjc__Ksjf874'

    for req_type in ['server', 'user', 'sign', 'ovpnclient']:
        print('--- Starting tests for request mode: "' + req_type + '" ---')

        # Create a key protected by cert_pass2 ...
        key2 = cert_makekey(passphrase=cert_pass2)

        # ... and switch its passphrase to cert_pass.
        key = cert_changepassphrase(key2, cert_pass2, cert_pass)

        # PKCS12 container built from the key with an empty string in
        # place of the certificate.
        keyp12 = ca_makep12(key, '',
                            in_passphrase=cert_pass, out_passphrase=cert_pass,
                            appca_id=appca_id2)
        print('--- PKCS12 unpacking output ---')
        print(cert_totext(keyp12, 'pkcs12', passphrase=cert_pass))

        subj = '/C=DE/ST=NRW/L=Herford/O=perfact::ema/CN=test-%s' % req_type
        print(cert_totext(key, passphrase=cert_pass, fmt='rsa'))

        # Reading the key with a wrong passphrase has to fail with an
        # AssertionError carrying the expected message. The "assert False"
        # inside the try block is caught by the same handler; its message
        # does not match the expected prefix, so the test still aborts.
        try:
            cert_totext(key, passphrase='wrong', fmt='rsa')
            assert False, "cert_totext should have failed."
        except AssertionError:
            t, e, tb = sys.exc_info()
            assert str(e).startswith('return code 1 [[unable to load'), \
                "cert_totext should have failed."

        pub = cert_getpub(key, passphrase=cert_pass)
        print('--- Public key ---')
        print(cert_totext(pub, fmt='rsapub'))

        # Alternative names are only requested for the matching type:
        # e-mail addresses for 'user', DNS names for 'server'.
        alt_emails = []
        if req_type == 'user':
            alt_emails = ['email1@nowhere.com', 'email2@nowhere.com']

        alt_dns = []
        if req_type == 'server':
            alt_dns = ['www.perfact.de', 'example.com', 'example.perfact.de']

        csr = cert_makecsr(key, subj, request_type=req_type,
                           passphrase=cert_pass, alt_emails=alt_emails,
                           alt_dns=alt_dns)
        print('--- Certificate signing request (CSR) with SAN records ---')
        print(cert_totext(csr, 'req'))

        cert = ca_signreq(csr, passphrase=ca_pass, appca_id=appca_id2)
        print('--- Certificate with SAN records ---')
        print(cert_totext(cert, 'x509'))
        print(cert_fingerprint(cert))

        # Different start/expire timestamps
        # Datetime format: %Y%m%d%H%M%SZ (four-digit year)
        start = '20170819222305Z'
        expire = '20180123000000Z'
        cert = ca_signreq(csr, passphrase=ca_pass, appca_id=appca_id2,
                          start=start, expire=expire)
        cert_text = cert_totext(cert, 'x509')
        print('--- Certificate with SAN records  and altered start date ---')
        print(cert_text)
        # The requested validity has to show up in the text dump.
        start_found = cert_text.find('Not Before: Aug 19 22:23:05 2017')
        expire_found = cert_text.find('Not After : Jan 23 00:00:00 2018')
        assert start_found > 0, 'Start date of certificate does not match!'
        assert expire_found > 0, 'Expire date of certificate does not match!'

        # This fails if unique_subject is set to "yes"
        cert2 = ca_signreq(csr, passphrase=ca_pass, appca_id=appca_id2)
        print('--- Certificate with duplicate CN ---')
        print(cert_totext(cert2, 'x509'))

        # Use cert to encrypt and decrypt: the round trip has to return
        # the original data for every supported format.
        data = b'aljfkashiuHIASHDKAjehrkj2bqi7fyczsd87vz'
        for form in ['DER', 'PEM', 'SMIME']:
            encr = cert_encryptdata(data, cert, form=form)
            assert isinstance(encr, bytes)
            print("encrypted: " + str([encr[:40], len(encr)]))
            decr = cert_decryptdata(encr, key, passphrase=cert_pass, form=form)
            assert data == decr
            print("decrypted: " + repr(decr))

        # PKCS12 container holding key and certificate, protected by a
        # passphrase different from the one of the key.
        p12 = ca_makep12(
            key, cert, in_passphrase=cert_pass, out_passphrase=p12_pass,
            appca_id=appca_id2)
        print(cert_totext(p12, 'pkcs12', passphrase=p12_pass))

        # Revoke certificate
        ca_revokecert(cert, passphrase=ca_pass, appca_id=appca_id2)

    # Print the CA certificate before and after extending it.
    cacert = ca_getcert(appca_id=appca_id2)
    print('--- CA1 certificate ---')
    print(cert_totext(cacert, 'x509'))
    ca_extend(appca_id=appca_id2, passphrase=ca_pass)
    cacert = ca_getcert(appca_id=appca_id2)
    print('--- CA2 certificate ---')
    print(cert_totext(cacert, 'x509'))

    # BEGIN Test for locking effects
    # Several threads sign requests against the same CA at the same time;
    # every single request has to end up signed.
    import threading
    request_type = 'ovpnclient'
    num_threads = 10
    num_csrs = 5
    csrs = []
    keys = []
    signed = []
    # Generate keys and csrs (num_csrs requests per thread)
    for i in range(num_threads * num_csrs):
        subj = '/C=DE/ST=NRW/L=Herford/O=perfact::ema/CN=test-%d' % i
        # 1024 bit keys - presumably chosen to keep key generation fast
        key = cert_makekey(rsabits=1024)
        csr = cert_makecsr(key=key, subj=subj,
                           request_type=request_type)
        keys.append(key)
        csrs.append(csr)
    print("%d signing requests made." % len(csrs))

    def signer_thread(start, stop):
        # Sign the requests csrs[start:stop] and collect the results in
        # the list "signed", which is shared by all threads.
        print('Starting thread for signing reqs %d-%d' % (start, stop))
        for csr in csrs[start:stop]:
            new_cert = ca_signreq(
                csr=csr, passphrase=ca_pass, appca_id=appca_id2)
            signed.append(new_cert)
        print('Thread finished for signing reqs %d-%d' % (start, stop))
    # Build threads, each one responsible for its own slice of csrs
    threads = []
    for thread_index in range(num_threads):
        start = num_csrs * thread_index
        stop = num_csrs * (thread_index+1)
        t = threading.Thread(target=signer_thread, args=(start, stop,))
        threads.append(t)
    # Start all threads almost simultaneously
    for t in threads:
        t.start()
    # Wait until all are home
    for t in threads:
        t.join()
    # All threads back home, but were they successful?
    # An exception inside a thread ends that thread only, which shows up
    # here as a smaller number of signed certificates.
    assert len(signed) == len(csrs), \
        ('Signed amount does not match. Signed: %s, Keys: %s' %
         (len(signed), len(csrs)))
    # END Test for locking effects

    print('All OK')
    print('Cleaning up...')
    # Only reached if everything above succeeded.
    shutil.rmtree(ca_tests_base)
    print('...Done!')
