import os
import time
import sys

from ipaddress import ip_address, ip_network, IPv4Network, IPv6Network

class Singleton(object):
    """Base class whose subclasses each have exactly one shared instance.

    The first instantiation of a subclass creates the object, every later
    instantiation returns that same object.  Python still calls ``__init__``
    on the returned object each time the class is called.
    """

    def __new__(cls, *args, **kw):
        # Look only at the class' own namespace: hasattr() would also find the
        # instance of a parent class and hand that one out for the subclass.
        if "_instance" not in cls.__dict__:
            allocate = super(Singleton, cls).__new__
            if allocate is object.__new__:
                # object.__new__ rejects extra arguments once __new__ is
                # overridden; constructor arguments are consumed by __init__.
                cls._instance = allocate(cls)
            else:
                # another class in the MRO customises allocation and may
                # need the constructor arguments
                cls._instance = allocate(cls, *args, **kw)
        return cls._instance


def singleton(cls, *args, **kw):
    """Class decorator: replace ``cls`` by a factory that builds it only once.

    The first call of the returned factory instantiates ``cls`` with the given
    arguments; all later calls ignore their arguments and return that first
    instance.  The extra ``*args``/``**kw`` of the decorator itself are unused.
    """
    created = {}

    def _singleton(*ctor_args, **ctor_kw):
        # guard clause: the instance already exists, hand it out unchanged
        if cls in created:
            return created[cls]
        created[cls] = cls(*ctor_args, **ctor_kw)
        return created[cls]

    return _singleton


@singleton
class Registry(object):
    # Root directory of the registry checkout; object files are read from
    # "<registryPath>/data/<type>/".  __init__ overrides it per instance.
    registryPath = None
    # Keys the parser warns about when they are missing, whatever the object type.
    required_keys = ["mnt-by", "source"]
    # Keys that may occur only once per object; the value is stored as a plain string.
    # A key that is also listed in multi_keys is handled as a multi key, because
    # the parser checks multi_keys first.
    single_keys = ["source", "descr", "single", "org", "policy", "status", "cidr", "max-length", "netname", "nic-hdl", "status", "abuse-mailbox", "as-block", "as-name", "as-set", "aut-num", "compression", "country", "dir-name", "domain", "fingerpr", "inet6num", "inetnum",
                   "key-cert", "method", "mntner", "organisation", "org-name", "owners", "owner", "person", "port", "ref", "registry", "role", "route", "route6", "route-set", "schema", "tinc-address", "tinc-file", "tinc-host", "tinc-keyset", "tinc-key"]
    # Keys that may occur several times per object; their values are collected in a list.
    multi_keys = ["mnt-by", "remarks", "tech-c", "admin-c", "org", "nserver", "ds-rdata", "member-of", "abuse-mailbox", "abuse-c", "address", "country", "e-mail", "fax-no", "mbrs-by-ref", "members", "mnt-lower", "origin", "phone", "pingable", "www", "zone-c", "auth",
                  "certif", "contact", "default", "export", "geo-loc", "geoloc", "import", "key", "language", "member", "mnt-ref", "mp-default", "mp-export", "mp-group", "mp-import", "mp-members", "network-owner", "nick", "owner", "pgp-fingerprint", "primary-key", "subnet", "url"]
    # Additional keys the parser warns about when missing, looked up by object
    # type (the name of the directory below data/).
    required_per_type = {"as-block": ["as-block", "policy"], "as-set": ["as-set"], "aut-num": ["aut-num", "as-name"], "dns": ["domain", "nserver"], "inet6num": ["inet6num", "cidr"], "inetnum": ["inetnum", "cidr"], "key-cert": ["key-cert", "method", "owner", "fingerpr", "certif"], "mntner": ["mntner"],
                         "organisation": ["organisation", "org-name"], "person": ["person", "nic-hdl"], "registry": ["registry", "url"], "role": ["role", "nic-hdl"], "route": ["route", "origin"], "route6": ["route6", "origin"], "route-set": ["route-set"], "schema": ["schema", "ref", "key"], "tinc-key": ["tinc-key", "tinc-host", "tinc-file"], "tinc-keyset": ["tinc-keyset", "member"]}

    def __init__(self, registryPath: str = None):
        if registryPath != None:
            self.registryPath = registryPath

        self.index = {}
        # cache: to not need to make expensive actions, expired: if something changed in the index, will rerun expensive actions
        self._cache = {"mntner-objects": {}, "expired": True}

    def _parse_from_content(self, objectType: str, objectFile: str):
        previous_key = None
        for line in self.index[objectType][objectFile]["_content"]:
            # start = " "*20 or just"+" -> continuation of previous key/value
            if line.startswith("                    ") or line == "+\n":
                if previous_key:
                    if previous_key in self.multi_keys:
                        self.index[objectType][objectFile][previous_key][-1] += "\n" + \
                            line[20:].rstrip()
                    else:
                        self.index[objectType][objectFile][previous_key] += "\n" + \
                            line[20:].rstrip()
                else:
                    print(
                        f"ERROR: empty/invalid first line(s) in {objectType}/{objectFile}")
            else:
                # if line.startswith("source"):
                #    print(f"INFO: source found in: {dir}/{objectPath}")
                _key = line.split(":")[0]
                if _key == "source":
                    pass
                if _key in self.multi_keys:
                    if not _key in self.index[objectType][objectFile]:
                        self.index[objectType][objectFile][_key] = [
                            line[20:].rstrip()]
                    else:
                        self.index[objectType][objectFile][_key].append(
                            line[20:].rstrip())
                elif _key in self.single_keys:
                    if not _key in self.index[objectType][objectFile]:
                        self.index[objectType][objectFile][_key] = line[20:].rstrip(
                        )
                    else:
                        print(
                            f"WARN: {objectType}/{objectFile} has multiple {_key}, which is has to be 'single'")
                else:
                    print(
                        f"WARN: invalid key {_key} found in {objectType}/{objectFile}")
                previous_key = _key

        for req_key in self.required_keys + self.required_per_type[objectType]:
            if not req_key in self.index[objectType][objectFile]:
                print(
                    f"WARN: required key {req_key} not found in {objectType}/{objectFile}")

    def _build_index(self, object: tuple = None):
        # fail when registryPath isn't initialized yet
        assert self.registryPath != None, f"registryPath not yet initialized"
        # load everything
        if not object:
            print("INFO: building full registry index")
            start_time = time.time()
            # TODO: get list of types from data/schema/* or data/*
            for dir in ["as-block", "as-set", "aut-num", "dns", "inet6num", "inetnum", "key-cert", "mntner", "organisation", "person", "registry", "role", "route", "route6", "route-set", "schema", "tinc-key", "tinc-keyset"]:
                self.index[dir] = {}
                try:
                    for objectFile in os.listdir(f"{self.registryPath}/data/{dir}/"):
                        self.index[dir][objectFile] = {}
                        with open(f"{self.registryPath}/data/{dir}/{objectFile}") as f:
                            self.index[dir][objectFile]["_content"] = f.readlines()
                        self._parse_from_content(dir, objectFile)
                except FileNotFoundError:
                    print(f"WARN: directory for {dir} doesn't exist")
            self._cache["expired"] = True
            print(
                f"INFO: building registry index done, took {time.time() - start_time}")

        # (re)load one specific object
        else:
            ...
            self._cache["expired"] = True

    def get_object(self, objectType: str, objectFile: str) -> tuple[bool, dict or str]:
        # check if index is not yet initialized
        if self.index == {}:
            print("INFO: requested get_object, but index not yet created")
            self._build_index()
        if objectType in self.index:
            if objectFile in self.index[objectType]:
                return True, self.index[objectType][objectFile]
        return False, "not found"

    def get_all_by_mntner(self, mntner: str, requestedType:str = None) -> tuple[bool, dict or str]:
        "requestedType: string of type, if None returns all"
        def _load_by_mntner(mntner: str) -> dict:
            ret = {}
            for objectType in self.index:
                if requestedType == None or objectType == requestedType:
                    for objectFile in self.index[objectType]:
                        if mntner in self.index[objectType][objectFile]["mnt-by"]:
                            if objectType in ret:
                                ret[objectType][objectFile] = self.index[objectType][objectFile]
                            else:
                                ret[objectType] = {
                                    objectFile: self.index[objectType][objectFile]}
            return ret

        # check if index is not yet initialized
        if self.index == {}:
            print("INFO: requested get_object, but index not yet created")
            self._build_index()

        if mntner in self._cache["mntner-objects"] and ("None" in self._cache["mntner-objects"][mntner].keys() or requestedType in self._cache["mntner-objects"][mntner].keys()):
            if self._cache["expired"]:
                # if the cache is expired: clear cache and reset "expired"
                self._cache["mntner-objects"] = {}
                self._cache["expired"] = False
                ret = _load_by_mntner(mntner)
                self._cache["mntner-objects"][mntner] = {}
                self._cache["mntner-objects"][mntner]["None" if requestedType == None else requestedType] = ret
                if ret == {}:
                    return False, f"no objects found for {mntner}"
                return True, ret

            else:
                return True, self._cache["mntner-objects"][mntner]["None" if requestedType == None else requestedType]
        else:
            if requestedType is not None and mntner in self._cache["mntner-objects"] and "None" in self._cache["mntner-objects"][mntner]:
                ret = {requestedType: self._cache["mntner-objects"]}
            else:
                ret = _load_by_mntner(mntner)
            self._cache["mntner-objects"][mntner] = ret
            if ret == {}:
                return False, f"no objects found for {mntner}"
            return True, ret

    def _save_object_to_file(self, objectType: str, objectFile: str):
        with open(f"{self.registryPath}/data/{objectType}/{objectFile}", "w") as f:
            f.writelines(self.index[objectType][objectFile]["_content"])

    def store_object(self, objectType: str, objectFile: str, content: iter):
        if objectType in self.index:
            if objectFile in self.index[objectType]:
                if type(content) == str:
                    self.index[objectType][objectFile]["_content"] = [
                        f"{line}\n" for line in content.split("\n")]
                elif type(content) in [list, tuple, iter]:
                    for line in content:
                        if type(line) != str:
                            raise ValueError(
                                f"content is {type(content)} instead of str or list of str")
                    self.index[objectType][objectFile]["_content"] = content
                else:
                    raise ValueError(
                        f"content is {type(content)} instead of str or list of str")
            else:
                raise KeyError(
                    f"type {objectType}/{objectFile} doesn't exist in index")
        else:
            raise KeyError(f"type {objectType} doesn't exist in index")

        self._save_object_to_file(objectType, objectFile)
        self._cache["expired"] = True
        self._cache["mntner-objects"] = {}

    def _build_records(self, parent_zone:str, record_name:str, TTL:int, nservers:[str], ds_rdata:[str] = []) -> [str]:
        records = []
        servers = dict()
        for nserver in nservers:
            server = nserver.split("\t", 1) if "\t" in nserver else nserver.split(" ", 1)
            if server[0] not in servers:
                servers[server[0]] = []
            if len(server) == 2:
                servers[server[0]].append(server[1])
        for server in servers:
            records.append(f"{record_name}. {TTL} IN NS {server}.")
            if not server.endswith(parent_zone):
                # nserver outside of your zone (also shouldn't have ip addresses, but who knows
                continue
            elif not server.endswith(record_name):
                # nserver address is not in this zone, won't add A/AAAA records for it
                continue
            for ip in servers[server]:
                # if there is a ip specified for this nserver
                try:
                    # try parsing the ip to an ip_address
                    # (it has to be stripped, because sometimes there are multiple whitespace between nserver hostname and ip)
                    address = ip_address(ip.strip())
                    if address.version == 6:
                        #records.append(f"{server}. {TTL} IN AAAA {address.compressed}")
                        records.append(f"{server}. {TTL} IN AAAA {ip.strip()}") # the java implementation of the dn42 master just copies the (strriped) ip ...
                    elif address.version == 4:
                        records.append(f"{server}. {TTL} IN A {address}")
                    else:
                        print(f"WARN: unknown ip version of '{ip}' for {server}")
                except ValueError:
                    print(f"WARN: '{ip}' for {server} isn't a a valid ip address")

        for ds in ds_rdata:
            records.append(f"{record_name}. {TTL} IN DS {ds}")

        return records

    def _build_registry_sync_zone(self, zone:str, TTL:int) -> [str]:
        # returns A/AAAA records for nservers of the form "$(reverse_ipv4).ipv4.${zone}" and "${reverse_ipv6}.ipv6.${zone}" as well as the records based on data/dns/${zone}

        zone = zone[:-1] if zone.endswith(".") else zone

        if not zone in self.index["dns"]:
            print(f"ERROR: object for dns/{zone} doesn't exist not generating a empty zone for it")
            return []

        domain_data = self.index["dns"][zone]

        records = self._build_records(zone, zone, TTL, domain_data["nserver"], domain_data["ds-rdata"] if "ds-rdata" in domain_data else [])

        v4_domain = f".ipv4.{zone}"
        v6_domain = f".ipv6.{zone}"

        joined = self.index["dns"] | self.index["inet6num"] | self.index["inetnum"]

        for key in joined:

            object_data = joined[key]

            if not "nserver" in object_data:
                continue

            for nserver in object_data["nserver"]:

                nserver = nserver.split(" ", 1)
                if not nserver[0].endswith(zone):
                    continue

                elif not len(nserver) == 1:
                    print(f"WARN: registry sync: {key} specifies ip address for a registry-sync address, ignoring that address")


                if nserver[0].endswith(v4_domain):
                    records.append(f"{nserver[0]}. {TTL} IN A {'.'.join(nserver[0].replace(v4_domain, '').split('.')[::-1])}")
                elif nserver[0].endswith(v6_domain):
                    _ip6 = nserver[0].replace(v6_domain, "").replace(".", "")[::-1]
                    try:
                        records.append(f"{nserver[0]}. {TTL} IN AAAA {ip_address(':'.join(a+b+c+d for a, b, c, d in zip(_ip6[::4], _ip6[1::4], _ip6[2::4], _ip6[3::4]))).compressed}")
                    except ValueError:
                        print(f"WARN: {nserver[0]} couldn't get parsed to ipv6 address, not adding it to the zone")
                else:
                    print(f"WARN: unknown registry-sync prefix in {key} not parsing that hostname")
        return records

    def _generate_forward_zone(self, zone:str, TTL:int) -> [str]:
        records = []
        zone = zone[:-1] if zone.endswith(".") else zone

        for domain in self.index["dns"]:
            if not domain.endswith(zone):
                # ignore none $zone domains
                continue
            domain_data = self.index["dns"][domain]
            records += self._build_records(zone, domain, TTL, domain_data["nserver"], domain_data["ds-rdata"] if "ds-rdata" in domain_data else [])

        return records

    def _generate_reverseV6_zone(self, zone:str, TTL:int) -> [str]:
        records = []
        zone = zone[:-1] if zone.endswith(".") else zone

        for objectFile in self.index["inet6num"]:
            net = IPv6Network(objectFile.replace("_", "/"))
            # generate domain from the network
            domain = ".".join(net.exploded.split("/")[0].replace(":", "")[(net.prefixlen//4)-1::-1]) + ".ip6.arpa"
            if not domain.endswith(zone):
                # ignore none $zone domains
                continue
            domain_data = self.index["inet6num"][objectFile]
            # ignore inet6nums without nservers
            if not "nserver" in domain_data:
                continue
            records += self._build_records(zone, domain, TTL, domain_data["nserver"], domain_data["ds-rdata"] if "ds-rdata" in domain_data else [])

        return records

    def _generate_reverseV4_zone(self, zone:str, TTL:int) -> [str]:
        records = []
        zone = zone[:-1] if zone.endswith(".") else zone

        for objectFile in self.index["inetnum"]:
            net = IPv4Network(objectFile.replace("_", "/"))
            if net.prefixlen > 24:
                domain = net.reverse_pointer
                if not domain.endswith(zone):
                    # ignore none $zone domains
                    continue
                domain_data = self.index["inetnum"][objectFile]
                # ignore inetnums without nservers
                if not "nserver" in domain_data:
                    continue

                records += self._build_records(zone, domain, TTL, domain_data["nserver"], domain_data["ds-rdata"] if "ds-rdata" in domain_data else [])

                # generate the CNAMEs for the single ips (because we don't have a full /24)
                if net.prefixlen == 32:
                    records.append(f"{net.network_address.reverse_pointer}. {TTL} IN CNAME {net.network_address.reverse_pointer.split('.',1)[0]}.{domain}.")
                else:
                    records += [f"{host.reverse_pointer}. {TTL} IN CNAME {host.reverse_pointer.split('.',1)[0]}.{domain}." for host in [net.network_address, *net.hosts(), net.broadcast_address]]

            elif net.prefixlen % 8 == 0:
                # this is a /8, /16, or /24 (/32s are handled above)
                net = IPv4Network(objectFile.replace("_", "/"))
                domain = ".".join(net.reverse_pointer.split(".")[(4-net.prefixlen//8):])

                if not domain.endswith(zone):
                    # ignore none $zone domains
                    continue
                domain_data = self.index["inetnum"][objectFile]
                # ignore inetnums without nservers
                if not "nserver" in domain_data:
                    continue

                records += self._build_records(zone, domain, TTL, domain_data["nserver"], domain_data["ds-rdata"] if "ds-rdata" in domain_data else [])

            else:
                # we now only have larger than /24 (but not "whole" subnets) remaining => multiple /24 zones
                net = IPv4Network(objectFile.replace("_", "/"))
                domain = ".".join(net.reverse_pointer.split(".")[(3-net.prefixlen//8):])

                if not domain.endswith(zone):
                    # ignore none $zone domains
                    continue
                domain_data = self.index["inetnum"][objectFile]
                # ignore inetnums without nservers
                if not "nserver" in domain_data:
                    continue
                for subnet in net.subnets(8-(net.prefixlen % 8)):
                    domain = ".".join(subnet.reverse_pointer.split(".")[(3-net.prefixlen//8):])
                    records += self._build_records(zone, domain, TTL, domain_data["nserver"], domain_data["ds-rdata"] if "ds-rdata" in domain_data else [])

        return records

    def generate_dns_zone(self, zone:str, TTL:int=900) -> [str]:
        # check if index is not yet initialized
        if self.index == {}:
            print("INFO: requested generate_dns_zone, but index not yet created")
            self._build_index()

        if zone.endswith("ip6.arpa."):
            # ipv6 reverse zone -> inet6num
            return list(set(self._generate_reverseV6_zone(zone, TTL)))

        elif zone.endswith("in-addr.arpa."):
            # ipv4 reverse zone -> inetnum
            return list(set(self._generate_reverseV4_zone(zone, TTL)))
        else:
            # other zone -> dns
            return list(set(self._generate_forward_zone(zone, TTL)))


if __name__ == "__main__":
    # manual smoke test: index the local "dn42-registry" checkout and print
    # everything that is maintained by LARE-MNT
    registry = Registry("dn42-registry")
    registry._build_index()
    lookup_result = registry.get_all_by_mntner("LARE-MNT")
    print(lookup_result)