import copy
import os
import yaml

# try using libYAML
try:
    from yaml import CSafeLoader as Loader
except ImportError:
    from yaml import SafeLoader as Loader

# Key that load_dir adds to the dict of every recursively loaded
# sub-directory, so that flatten can tell directory contents apart from
# ordinary entries. It is a bytes object, presumably so it cannot collide
# with the str keys that YAML mappings normally contain.
recursion_marker = b"__recursion_marker__"


def load_yaml_file(path):
    '''Read the file at *path* and parse it as YAML into a data structure.

    Parsing uses the module-level ``Loader`` (libYAML's CSafeLoader when it
    is available, otherwise the pure-Python SafeLoader).

    :param path: path of the file to read.
    :returns: the parsed data. This may be ``None``, e.g. for a file with
        no YAML content.
    :raises Exception: chained to the original error, if the file cannot be
        read or is not valid YAML.
    '''
    try:
        # Context manager so the handle is closed promptly, even on error.
        with open(path, 'rb') as fh:
            raw = fh.read()
    except OSError as e:  # IOError is an alias of OSError in Python 3
        raise Exception("file '{}' could not be read".format(path)) from e

    try:
        data = yaml.load(raw, Loader=Loader)
    except yaml.YAMLError as e:
        raise Exception("Syntax error in yaml file '{}'".format(path)) from e

    return data


def load_dir(directory, defaults=None, transforms=None, flatten=False):
    '''Load YAML objects from the files in *directory* into a dict keyed by
    file name.

    Sub-directories are loaded recursively. Each one becomes a nested dict
    tagged with ``recursion_marker``, so that ``flatten`` can later splice
    its entries into the parent. Editor swap files (``*.swp``) are ignored.
    File names beginning with two underscores contain meta information
    rather than entries:

    ``__defaults``
        A YAML mapping merged underneath every entry in this directory and
        its sub-directories. Values from the entry file itself win, and
        these defaults win over inherited ones.
    ``__transforms``
        Python source that is executed. The names it defines are combined
        with the inherited transforms (see ``merge_transforms``). If an
        ``on_load`` callable results, it is called as
        ``on_load(name, data)`` for every file entry (not for
        sub-directories) and must return the new ``(name, data)`` pair.

    :param directory: path of the directory to load.
    :param defaults: inherited default values; ``None`` means none.
    :param transforms: inherited transforms (name -> callable); ``None``
        means none.
    :param flatten: if true, merge all nested directory entries into one
        flat dict (see ``flatten``).
    :returns: dict mapping entry names to loaded data.
    :raises TypeError: if ``__defaults`` or an entry file does not contain a
        YAML mapping.
    '''
    # None instead of shared mutable ``{}`` defaults.
    defaults = {} if defaults is None else defaults
    transforms = {} if transforms is None else transforms

    files, meta = _partition_dir(directory)

    if '__defaults' in meta:
        path = os.path.join(directory, '__defaults')
        data = _as_mapping(load_yaml_file(path) or {}, path)
        defaults = {**defaults, **data}

    if '__transforms' in meta:
        transforms = _load_transforms(
            os.path.join(directory, '__transforms'), transforms)

    entries = {}
    for name in files:
        path = os.path.join(directory, name)

        if os.path.isdir(path):
            data = load_dir(path, defaults=defaults, transforms=transforms)
            data[recursion_marker] = True
        else:
            if os.path.getsize(path) == 0:
                data = {}
            else:
                data = _as_mapping(load_yaml_file(path) or {}, path)

            # Deep-copy so entries never share nested default values.
            data = {**copy.deepcopy(defaults), **data}

            if 'on_load' in transforms:
                name, data = transforms['on_load'](name, data)

        entries[name] = data

    if flatten:
        # The ``flatten`` parameter shadows the module-level function of the
        # same name, so look the function up explicitly.
        entries = globals()['flatten'](entries)

    return entries


def _partition_dir(directory):
    '''Split the names in *directory* into ``(entry names, meta names)``.

    Editor swap files are skipped. Names are sorted so that loading order
    does not depend on the arbitrary order returned by ``os.listdir``.
    '''
    files = []
    meta = []
    for name in sorted(os.listdir(directory)):
        # exclude swap-files from editors
        if name.endswith(".swp"):
            continue
        if name.startswith("__"):
            meta.append(name)
        else:
            files.append(name)
    return files, meta


def _as_mapping(data, path):
    '''Return *data* unchanged if it is a dict; otherwise raise a TypeError
    that names the offending file *path*.'''
    if not isinstance(data, dict):
        raise TypeError("yaml file '{}' must contain a mapping, not {}".format(
            path, type(data).__name__))
    return data


def _load_transforms(path, parent_transforms):
    '''Execute the ``__transforms`` script at *path* and return the names it
    defines combined with *parent_transforms*.

    When a name exists in both, the two are chained with
    ``merge_transforms`` (the child's runs first). Otherwise whichever one
    exists is used.
    '''
    # NOTE(security): this executes arbitrary Python found in the data
    # directory; only point load_dir at trusted directories.
    with open(path, 'rb') as fh:
        source = fh.read()

    # Run the script in its own namespace. The merged dict built below must
    # NOT be the script's globals: otherwise, replacing a name defined at
    # both levels with a merged wrapper would also change what the script's
    # own functions see under that name.
    namespace = {}
    exec(compile(source, path, 'exec'), namespace)

    transforms = {name: obj for name, obj in namespace.items()
                  if name != '__builtins__'}

    for name, parent in parent_transforms.items():
        if name == '__builtins__':
            continue
        if name in transforms:
            # no curry. curry makes me hungry
            transforms[name] = merge_transforms(transforms[name], parent)
        else:
            transforms[name] = parent

    return transforms


def flatten(data: dict) -> dict:
    '''Flatten the nested directory entries produced by ``load_dir``.

    Values tagged with ``recursion_marker`` (sub-directories) are flattened
    recursively, and their entries are merged into the result. The
    sub-directory's own name is dropped. Every other value is kept under its
    own key. The input and its nested dicts are not modified.

    :param data: mapping of entry name to entry data.
    :returns: a new flat dict of entry name to entry data.
    :raises Exception: if the same name appears more than once anywhere in
        the tree.
    '''
    result = {}

    def add(key, value):
        if key in result:
            raise Exception("name '{}' used more than once".format(key))
        result[key] = value

    for key, value in data.items():
        # Only dicts carry the marker. The isinstance guard also keeps
        # non-dict values (e.g. strings or None returned by an on_load
        # transform) from breaking the membership test.
        if not (isinstance(value, dict) and recursion_marker in value):
            add(key, value)
            continue

        # Strip the marker from a shallow copy instead of deleting it in
        # place, so the caller's nested dicts stay intact.
        nested = dict(value)
        del nested[recursion_marker]

        for name, definition in flatten(nested).items():
            add(name, definition)

    return result

def merge_transforms(t1, t2):
    '''Chain two transforms into a single one.

    The returned callable takes ``(filename, data)`` and passes it through
    *t1*. It unpacks the ``(filename, data)`` pair that *t1* returns, feeds
    it to *t2*, and returns whatever *t2* returns.
    '''
    def chained(filename, data):
        intermediate_name, intermediate_data = t1(filename, data)
        return t2(intermediate_name, intermediate_data)

    return chained
