#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# tcpserver.py  -  TCP-Server as dummy endpoint for testing
#
# Copyright (C) 2017 Jens Hinghaus <jens.hinghaus@perfact.de>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#

import threading
import six
import ssl
import os
from .. import generic

if six.PY2:
    from SocketServer import BaseRequestHandler, TCPServer
    from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
    from urlparse import urlparse, parse_qs
else:
    from socketserver import BaseRequestHandler, TCPServer
    from http.server import HTTPServer, BaseHTTPRequestHandler
    from urllib.parse import urlparse, parse_qs

# Set SO_REUSEADDR for every TCPServer (and subclass) created in this
# process, so a test can bind a port again right after a previous server on
# it was closed.
TCPServer.allow_reuse_address = True

# we need to change the working directory to access our certificate
# and key files for the TCP-Server: relative paths to them then resolve
# from this module's directory.
# NOTE(review): this changes the process-wide working directory as a side
# effect of merely importing this module.
module_dir = os.path.dirname(os.path.realpath(__file__))
os.chdir(module_dir)
# print('Now in dir: %s' % os.getcwd())

# Canned raw HTTP response (status line, headers, blank line, JSON body),
# presumably standing in for an OpenID provider's token response
# (id_token / access_token / token_type / expires_in).
# id_token contains a JWT with example values. The backslash at the end of
# each token line inside the literal is a line continuation, so the token
# is one unbroken string in the resulting value.
# NOTE(review): the literal starts with a newline and uses bare "\n" line
# endings instead of CRLF; strict HTTP clients (e.g. Python's http.client)
# may reject the status line - confirm the consumer tolerates this.
OP_BACKCHANNEL_RESPONSE = '''
HTTP/1.1 200 OK
Content-Type: application/json
Cache-Control: no-store
Pragma: no-cache

{
  "id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJub25jZSI6InhqczhZdTAtM\
iIsInN1YiI6InRlc3R1c2VyIiwiaXNzIjoiaHR0cHM6Ly9zZXJ2ZXIuZXhhbXBsZS5jb20iLCJleH\
AiOjIxNDczODIwMDAsImlhdCI6MTUwNzIxMjMyOCwiYXVkIjoic29tZS12YWx1ZSJ9.fCZwXdLb6e\
sU-7YE5y6zqk94Vc167LWuoACOYFK7HV8",
  "access_token": "SlAV32hkKG",
  "token_type": "Bearer",
  "expires_in": 3600
}
'''


class MockHandler(BaseHTTPRequestHandler):
    '''
    Simple HTTP(S) request handler for testing purposes.
    Extend this class so it fits your needs, e.g. you could implement a
    do_PUT function to handle PUT requests...

    The class variables below should be used to configure the server for the
    next request.
    It checks if the requested URL, the received headers and the parameters
    (key/value) match and responds with a preconfigured answer. If no answer
    is configured, the parameters will just be echoed back to the caller.
    Any mismatch is answered with status 500 and a message describing it.
    '''
    # Configuration which should be overridden (on this class or in a
    # subclass); it is read as self.xxx while a request is handled.
    # Accepted request paths; an empty list accepts every path.
    expected_urls = []
    # Expected parameter values, keyed by parameter name.
    expected_params = {}
    # Response body sent when all parameters match; None echoes them back.
    answer = None
    # Expected header values, keyed by header name (compared
    # case-insensitively); a falsy value means "don't check this header".
    expected_headers = {
        'host': None,
        'accept-encoding': None,
        'accept': None,
        'user-agent': None,
        'content-length': None,
        'content-type': None
    }

    def parse_common(self):
        '''Split the request path into self.query and self.received_url.'''
        parsed = urlparse(self.path)
        self.query = parsed.query
        self.received_url = parsed.path
        # An empty mapping still supports both iteration and .get()
        self.headers = self.headers or {}

    @staticmethod
    def _parse_params(query_string):
        '''Parse a urlencoded string into a {key: first value} dict.'''
        # parse_qs maps every key to a list of values; keep the first one
        return dict(
            (key, values[0])
            for key, values in parse_qs(query_string).items()
        )

    def debug_info(self):
        '''Print the received path, headers and parameters.'''
        print('HTTP server received path: %s' % str(self.received_url))
        print('HTTP server received headers: %s' % str(self.headers))
        # params_received is a dictionary
        print('HTTP server received parameters: %s'
              % str(self.received_params))

    def compare_input(self):
        '''
        Validate the request against the configured expectations and send
        the response: 500 with an explanation on the first mismatch,
        otherwise 200 with self.answer or the echoed parameters.
        '''
        # Lowercase the expected header names in a local copy: the class
        # level dict is shared by all requests and must not be mutated
        # (mutating it while iterating its keys breaks on Python 3).
        expected_headers = dict(
            (name.lower(), value)
            for name, value in self.expected_headers.items()
        )

        # check received headers; header names are case-insensitive and
        # Python 3 yields them in their original case
        for header in self.headers:
            received_value = self.headers.get(header)
            expected_value = expected_headers.get(str(header).lower())
            if expected_value and received_value != expected_value:
                msg = ('Header %s with value %s was not expected'
                       % (header, received_value))
                self.create_response(code=500, msg=msg)
                return

        # check the URL if any were configured - also for requests that
        # carry no parameters at all
        if (self.expected_urls and
                self.received_url not in self.expected_urls):
            msg = 'URL %s is not accepted' % self.received_url
            self.create_response(code=500, msg=msg)
            return

        # validate every received parameter
        for key, received_value in self.received_params.items():
            expected_value = self.expected_params.get(key)
            if expected_value != received_value:
                msg = ('Parameters did not match for Key: "%s" '
                       'Received: "%s" '
                       'Expected: "%s"' % (key,
                                           received_value,
                                           expected_value))
                self.create_response(code=500, msg=msg)
                return

        if self.answer:
            # write out the preconfigured answer
            msg = self.answer
        else:
            # just echo back the given parameters and values
            msg = ''.join(
                key + '=' + value
                for key, value in self.received_params.items()
            )
        self.create_response(msg=msg)

    def do_GET(self):
        '''Respond to a GET request by validating its query parameters.'''
        self.parse_common()
        self.received_params = self._parse_params(self.query)
        self.debug_info()
        self.compare_input()
        return

    def do_POST(self):
        '''
        Respond to a POST request by validating its urlencoded body.
        '''
        self.parse_common()
        # A missing Content-Length is treated as an empty body
        length = int(self.headers.get('Content-Length') or 0)
        body = self.rfile.read(length).decode('utf-8')
        self.received_params = self._parse_params(body)
        self.debug_info()
        self.compare_input()
        return

    def create_response(self, msg=None, code=200, content_type='text/html'):
        '''Send status *code*, a Content-type header and *msg* as body.'''
        self.send_response(code)
        self.send_header("Content-type", content_type)
        self.end_headers()
        self.wfile.write(generic.to_bytes(msg))
        return


class StaticResponder(BaseRequestHandler):
    ''' RequestHandler responding with static content.

    The payload is the ``static_content`` attribute of the server (set by
    startStaticServer). Text payloads are UTF-8 encoded before sending, so
    both byte strings and text can be served.
    '''

    def handle(self):
        # self.request is the TCP socket connected to the client; read the
        # request once (at most 2024 bytes) and keep it in self.data - its
        # content is not inspected.
        self.data = self.request.recv(2024)
        content = self.server.static_content
        # socket.sendall() only accepts bytes on Python 3; encode text
        # (type(u'') is unicode on Python 2 and str on Python 3).
        if isinstance(content, type(u'')):
            content = content.encode('utf-8')
        # just send back static data
        self.request.sendall(content)


def startStaticServer(port, content, host='127.0.0.1'):
    ''' Start a TCP server that answers every connection with *content*.

    serve_forever() runs in a freshly started thread; hand the returned
    (server, thread) pair to stopServer() to shut it down again.
    '''
    tcp_server = TCPServer((host, port), StaticResponder)
    # StaticResponder picks its payload up from this server attribute
    tcp_server.static_content = content
    worker = threading.Thread(target=tcp_server.serve_forever)
    worker.start()
    return tcp_server, worker


def stopServer(server, server_thread):
    ''' Shut down a server started by startStaticServer() or
    startDynamicServer() and wait for its thread to finish. '''
    # In sequence: stop the serve_forever loop, close the listening socket,
    # then wait for the serving thread to end.
    for step in (server.shutdown, server.server_close, server_thread.join):
        step()


def startDynamicServer(host='127.0.0.1', port=4443, handler=None,
                       use_ssl=True, certfile=None, keyfile=None):
    ''' Start a Basic HTTP server on given host & port with optional ssl
    support, serving in a background thread. A handler should be given to
    serve your requests; MockHandler is used if none is given.

    :param certfile: PEM certificate used when use_ssl is set; defaults to
        ../assets/tests/tcpserver_cert.pem relative to this module.
    :param keyfile: PEM private key used when use_ssl is set; defaults to
        ../assets/tests/tcpserver_key.pem relative to this module.
    :returns: tuple (server, server_thread); pass both to stopServer().
    :raises: whatever the ssl module raises for unusable cert/key files
        (OSError / ssl.SSLError); the bound socket is closed first.
    '''
    if not handler:
        handler = MockHandler

    server = HTTPServer(
        (host, port),
        handler
    )
    if use_ssl:
        # Resolve the default files against this module's directory so they
        # do not depend on the current working directory.
        assets_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            os.pardir, 'assets', 'tests'
        )
        if certfile is None:
            certfile = os.path.join(assets_dir, 'tcpserver_cert.pem')
        if keyfile is None:
            keyfile = os.path.join(assets_dir, 'tcpserver_key.pem')
        try:
            # ssl.wrap_socket() was removed in Python 3.12; use a context.
            protocol = getattr(ssl, 'PROTOCOL_TLS_SERVER', None)
            if protocol is None:
                protocol = ssl.PROTOCOL_SSLv23
            context = ssl.SSLContext(protocol)
            context.load_cert_chain(certfile=certfile, keyfile=keyfile)
            server.socket = context.wrap_socket(
                server.socket,
                server_side=True
            )
        except Exception:
            # Don't leak the already bound listening socket.
            server.server_close()
            raise
    server_thread = threading.Thread(target=server.serve_forever)
    server_thread.start()
    return server, server_thread


if __name__ == '__main__':
    # debug: run the default server (MockHandler, SSL) on the console until
    # it is interrupted with Ctrl+C.
    dynServer, dynServer_thread = startDynamicServer()
    print('Server started!')
    try:
        # Wait on the server thread instead of busy-spinning a CPU core; the
        # finite timeout keeps the wait responsive to Ctrl+C (an untimed
        # join() is not interruptible on every Python version).
        while dynServer_thread.is_alive():
            dynServer_thread.join(1)
    except KeyboardInterrupt:
        print('Interrupted! Shutting down...')
        stopServer(dynServer, dynServer_thread)
