#!/usr/bin/env python3

# Location of the registry checkout, relative to the current working directory;
# its data/dns, data/inetnum and data/inet6num directories are searched for objects.
REGISTRY_PATH = "../registry"

import hashlib
import base64
import struct
import sys
import os
import subprocess
import json

try:
    import dns.name
    import dns.query
    import dns.dnssec
    import dns.message
    import dns.resolver
    import dns.rdatatype
except ImportError:
    print()
    print("CRITICAL: this script requires the 'dnspython' libary, please install it using `python3 -m pip install dnspython`")
    print()
    exit(1)

# imported separately for the exception classes caught further down (dns.exception, binascii.Error)
import dns.exception
import binascii

# Number of errors that occurred during the run; incremented by the checks below
# and used as the script's exit status.
errors= 0

# step1:


def _grep_registry_files(registry_path, subdir, mntner):
    """Return paths of files below <registry_path>/data/<subdir> that contain " <mntner>".

    The search is recursive and case-insensitive (grep -Ril); grep's stderr is
    not captured, so problems such as a missing directory are only printed.
    """
    directory = os.path.join(registry_path, "data", subdir)
    result = subprocess.run(["grep", "-Ril", f" {mntner}", directory],
                            stdout=subprocess.PIPE)
    return result.stdout.decode().splitlines()


def _parse_nserver(value):
    """Split an nserver attribute value into [hostname, address] (address None if absent).

    Returns None for values that are ignored: empty values, registry-sync.dn42
    nservers and values with more than two whitespace-separated fields.
    """
    fields = value.split()
    if not fields or "registry-sync.dn42" in fields[0]:
        return None
    if len(fields) > 2:
        print(f"WARN: ignoring malformed nserver value '{value}'")
        return None
    if len(fields) == 1:
        # the nserver's address has to be looked up in another dns object
        return [fields[0], None]
    return fields


def _read_registry_object(path, name_attribute=None):
    """Collect the nserver and ds-rdata attributes of one registry file.

    Attribute values are read from column 20 onwards. Returns a tuple
    (nservers, ds_rdata, name_value):
      nservers   -- {hostname: [address, ...]}, or {hostname: None} when this
                    file gives no address for that hostname
      ds_rdata   -- lowercased ds-rdata values in file order
      name_value -- value of the last line starting with name_attribute
                    (None if name_attribute is None or no such line exists)
    """
    nservers = {}
    ds_rdata = []
    name_value = None
    with open(path, encoding="utf-8", errors="replace") as handle:
        for line in handle:
            line = line.rstrip("\n")
            value = line[20:]
            if name_attribute is not None and line.startswith(name_attribute):
                name_value = value
            elif line.startswith("nserver"):
                parsed = _parse_nserver(value)
                if parsed is None:
                    continue
                hostname, address = parsed
                # always keep a list so multiple addresses (IPv4 + IPv6) accumulate
                addresses = nservers.setdefault(hostname, [])
                if address is not None:
                    addresses.append(address)
            elif line.startswith("ds-rdata:"):
                ds_rdata.append(value.lower())
    # callers use None to mean "no address known for this nserver"
    nservers = {host: (addrs or None) for host, addrs in nservers.items()}
    return nservers, ds_rdata, name_value


def _inetnum_reverse_zone(cidr):
    """Return the in-addr.arpa zone name for an inetnum cidr value like '172.20.1.0/24'."""
    reverse = ".".join(cidr.split(".")[::-1]) + ".in-addr.arpa"
    prefix_length = int(cidr.split("/")[1])
    if prefix_length == 24:
        return reverse.replace("0/24.", "")
    if prefix_length == 16:
        return reverse.replace("0/16.0.", "")
    if prefix_length == 8:
        return reverse.replace("0/8.0.0.", "")
    if prefix_length <= 24:
        # TODO: implement creation of multiple zones for every /24 within
        print(f"WARN: currently only ipv4 subnets with length >=24 or 16 or 8 are possible to be checked: relavent inetnum {cidr}")
    return reverse


def _inet6num_reverse_zone(inet6num):
    """Return the ip6.arpa zone for an inet6num value '<first> - <last>' (fully expanded addresses).

    The zone covers the hex digits shared by the first and last address.
    """
    lowest, highest = inet6num.replace(":", "").split(" - ")
    zone = "ip6.arpa"
    for low_digit, high_digit in zip(lowest, highest):
        if low_digit != high_digit:
            break
        zone = low_digit + "." + zone
    return zone


def _resolve_missing_nservers(domains):
    """Fill in addresses of nservers given without one, from the loaded domains.

    The most specific loaded domain that is a suffix of the nserver name is used;
    if it doesn't list that nserver a warning is printed and it stays None.
    """
    for domain, data in domains.items():
        for nserver, addresses in list(data["nserver"].items()):
            if addresses is not None:
                continue
            labels = nserver.split(".")
            for i in range(len(labels), 1, -1):
                parent = ".".join(labels[-i:])
                if parent not in domains:
                    continue
                try:
                    data["nserver"][nserver] = domains[parent]["nserver"][nserver]
                except KeyError:
                    # the domain for the nserver is loaded, but the nserver itself isn't in it
                    print(
                        f"Warn: the nserver {nserver} specified in {domain} wasn't found")
                # stop at the most specific matching domain
                break


def get_domain_by_mntner(mntner, registry_path=None):
    """Get the domains and reverse zones (with nservers) maintained by ``mntner``.

    Searches the dns, inetnum and inet6num directories of the registry for files
    mentioning the mntner. Every matching dns object is returned. inetnum and
    inet6num objects are returned only when they list at least one nserver, keyed
    by their in-addr.arpa / ip6.arpa zone name.

    :param mntner: maintainer handle, e.g. "YOU-MNT"
    :param registry_path: registry checkout to search; defaults to REGISTRY_PATH
    :return: {zone: {"nserver": {hostname: [address, ...] or None},
                     "ds-rdata": [lowercased ds-rdata, ...]}}
        Nservers listed without an address get the addresses of the same nserver
        in another returned domain; they stay None if it isn't found there.
    """
    if registry_path is None:
        registry_path = REGISTRY_PATH

    domains = {}

    for path in _grep_registry_files(registry_path, "dns", mntner):
        nservers, ds_rdata, _ = _read_registry_object(path)
        domains[os.path.basename(path)] = {"nserver": nservers, "ds-rdata": ds_rdata}

    reverse_sources = (
        ("inetnum", "cidr", _inetnum_reverse_zone),
        ("inet6num", "inet6num", _inet6num_reverse_zone),
    )
    for subdir, name_attribute, to_zone in reverse_sources:
        for path in _grep_registry_files(registry_path, subdir, mntner):
            nservers, ds_rdata, name_value = _read_registry_object(path, name_attribute)
            # objects without nservers have nothing to check
            if not nservers:
                continue
            if name_value is None:
                print(f"WARN: no {name_attribute} attribute found in {path}, skipping it")
                continue
            domains[to_zone(name_value)] = {"nserver": nservers, "ds-rdata": ds_rdata}

    _resolve_missing_nservers(domains)
    return domains


def get_dnskey(domain_name, nserver):
    """Query the name server at address ``nserver`` for the DNSKEYs of ``domain_name``.

    The query is sent with dns.query.udp_with_fallback (UDP, TCP for truncated answers).

    :return: list of DNSKEY rdata strings "<flags> <protocol> <algorithm> <base64 key>",
        one entry per key in the answer (the key may be split by spaces), or False
        if the query failed. Failures are printed and counted in the global ``errors``.
    """
    global errors
    try:
        request = dns.message.make_query(
            domain_name, dns.rdatatype.DNSKEY, want_dnssec=False)
        response, _ = dns.query.udp_with_fallback(request, nserver, timeout=2)
    except dns.exception.Timeout:
        print(f"WARN: querying {nserver} for {domain_name} timed out")
        errors += 1
        return False
    except dns.query.UnexpectedSource as e:
        print(f"ERROR: server replied with different different ip than requested: error: {e}")
        errors += 1
        return False
    except (dns.exception.DNSException, OSError, ValueError) as e:
        # e.g. malformed reply, unreachable network or an invalid address string
        print(f"ERROR: querying {nserver} for {domain_name} failed: {e}")
        errors += 1
        return False
    if response.rcode() != 0:
        # HANDLE QUERY FAILED (SERVER ERROR OR NO DNSKEY RECORD)
        print(
            f"WARN: query for a DNSKEY on {domain_name} failed on {nserver}, returncode: {response.rcode()}")
        errors += 1
        return False
    # An RRset can hold several keys (e.g. ZSK and KSK): return every one of them,
    # and ignore any non-DNSKEY RRsets in the answer.
    return [
        rdata.to_text()
        for rrset in response.answer
        if rrset.rdtype == dns.rdatatype.DNSKEY
        for rdata in rrset
    ]
# end_step1


# step 2: <start dnskey_to_DS.py based on https://gist.github.com/wido/4c6288b2f5ba6d16fce37dca3fc2cb4a >
"""
Generate a DNSSEC DS record based on the incoming DNSKEY record
The DNSKEY can be found using for example 'dig':
$ dig DNSKEY secure.widodh.nl
The output can then be parsed with the following code to generate a DS record
for use in the parent DNS zone
Author: Wido den Hollander <wido@widodh.nl>
Many thanks to this blogpost: https://www.v13.gr/blog/?p=239
"""


def _calc_keyid(flags, protocol, algorithm, dnskey):
    """Compute the key tag of a DNSKEY (RFC 4034, Appendix B).

    flags, protocol and algorithm may be ints or numeric strings; dnskey is the
    base64-encoded public key.
    """
    rdata = struct.pack('!HBB', int(flags), int(protocol), int(algorithm))
    rdata += base64.b64decode(dnskey)

    # Sum the RDATA as big-endian 16-bit words: bytes at even offsets are the
    # high byte of a word, bytes at odd offsets the low byte.
    total = sum(
        value << 8 if offset % 2 == 0 else value
        for offset, value in enumerate(rdata)
    )

    # Fold everything above 16 bits back into the low 16 bits.
    low_word = total & 0xFFFF
    carry = total >> 16
    return (low_word + carry) & 0xFFFF


def _calc_ds(domain, flags, protocol, algorithm, dnskey):
    """Return the SHA-1 and SHA-256 DS digests (uppercase hex) of a DNSKEY.

    The digest input is the owner name in canonical wire format (RFC 4034
    section 6.2: ASCII letters lowercased) followed by the DNSKEY RDATA.

    :param domain: owner name, with or without trailing dot ("." is the root)
    :param dnskey: base64-encoded public key
    :return: {"sha1": hex, "sha256": hex}
    """
    name = domain.rstrip('.')
    signature = bytes()
    if name:
        for label in name.split('.'):
            # length prefix must count encoded bytes; bytes.lower() only touches ASCII
            encoded = label.encode().lower()
            signature += struct.pack('B', len(encoded)) + encoded
    # terminating zero-length root label
    signature += b'\x00'

    signature += struct.pack('!HBB', int(flags), int(protocol), int(algorithm))
    signature += base64.b64decode(dnskey)

    return {
        'sha1':    hashlib.sha1(signature).hexdigest().upper(),
        'sha256':  hashlib.sha256(signature).hexdigest().upper(),
    }


def dnskey_to_ds(domain, dnskey):
    """Convert DNSKEY rdata text into SHA-1 and SHA-256 DS rdata strings.

    :param domain: owner name of the DNSKEY (trailing dot optional)
    :param dnskey: "<flags> <protocol> <algorithm> <base64 key>"; fields may be
        separated by any whitespace and the key may itself contain whitespace
    :return: ["<keytag> <algorithm> 1 <sha1 hex>", "<keytag> <algorithm> 2 <sha256 hex>"]
        with lowercase digests
    :raises IndexError: if fewer than four fields are given
    :raises binascii.Error: if the key is not valid base64
    """
    fields = dnskey.split(None, 3)

    flags = fields[0]
    protocol = fields[1]
    algorithm = fields[2]
    # the base64 key may be split into chunks; drop all whitespace before decoding
    key = "".join(fields[3].split())

    keyid = _calc_keyid(flags, protocol, algorithm, key)
    ds = _calc_ds(domain, flags, protocol, algorithm, key)

    return [
        f"{keyid} {algorithm} 1 {ds['sha1'].lower()}",
        f"{keyid} {algorithm} 2 {ds['sha256'].lower()}",
    ]

# if __name__ == "__main__":
#    print(len(sys.argv))
#    if len(sys.argv) == 1:
#        print("no data specified, please enter manually")
#        DOMAIN = input("enter domain: ")
#        DNSKEY = input("enter DNSKEY: ")
#
#        print(dnskey_to_ds(DOMAIN, DNSKEY))
#    elif len(sys.argv) in [9,10]:
#        # user pasted full dig output
#        DOMAIN = sys.argv[1]
#        DNSKEY = " ".join(sys.argv[5:])
#        print(DOMAIN)
#        print(DNSKEY)
#        print(dnskey_to_ds(DOMAIN,DNSKEY))

# step2: <end dnskey_to_DS.py>

# step3: start: partially stolen from: https://stackoverflow.com/questions/26137036/programmatically-check-if-domains-are-dnssec-protected


def check_dnssec(domain_name, domain_data):
    """Check that the nservers of a domain serve a DNSKEY RRset signed by its own keys.

    For every known nserver address:
    1. query the SOA to see whether the server answers;
    2. query the DNSKEY RRset with DNSSEC records requested;
    3. validate the RRSIG over that RRset against the returned keys.

    Problems are printed. Failed or invalid DNSKEY answers increment the
    global ``errors``.

    :param domain_name: zone to check
    :param domain_data: {"nserver": {hostname: [address, ...] or None}, "ds-rdata": [...]}
    :return: True if the domain has no ds-rdata (nothing to check) or at least one
        address returned a validating DNSKEY RRset, False otherwise.
    """
    global errors

    if domain_data["ds-rdata"] == []:
        print(
            f"NOTE: {domain_name} doesn't have ds-rdata configured, not checking it")
        return True

    success = False
    for nserver, addresses in domain_data["nserver"].items():
        # None: the address was neither given nor found in another loaded dns object
        if addresses is None:
            print(
                f"INFO: ip address(es) for nserver '{nserver}' in '{domain_name}' isn't specified/loaded")
            continue
        for nsaddr in addresses:
            try:
                # the SOA answer itself isn't used; this only checks that the server responds
                request = dns.message.make_query(
                    domain_name, dns.rdatatype.SOA, want_dnssec=False)
                dns.query.udp_with_fallback(request, nsaddr, timeout=2)
                # want_dnssec=True so the RRSIG over the DNSKEY RRset is returned too
                request = dns.message.make_query(
                    domain_name, dns.rdatatype.DNSKEY, want_dnssec=True)
                response, _ = dns.query.udp_with_fallback(request, nsaddr, timeout=2)
            except dns.exception.Timeout:
                print(
                    f"WARN: querying {nserver} ({nsaddr}) for {domain_name} timed out")
                continue

            if response.rcode() != 0:
                # HANDLE QUERY FAILED (SERVER ERROR OR NO DNSKEY RECORD)
                print(
                    f"WARN: query for a DNSKEY on {domain_name} failed on {nserver} ({nsaddr}), returncode: {response.rcode()}")
                errors += 1
                continue

            # answer should contain two RRsets: DNSKEY and RRSIG(DNSKEY)
            answer = response.answer
            if len(answer) != 2:
                print(
                    f"ERROR: query for a DNSKEY on {domain_name} failed on {nserver} ({nsaddr}), invalid answer length: {len(answer)}")
                errors += 1
                continue

            # select the RRsets by type instead of relying on their order in the answer
            dnskey_rrset = next(
                (rrset for rrset in answer if rrset.rdtype == dns.rdatatype.DNSKEY), None)
            rrsig_rrset = next(
                (rrset for rrset in answer if rrset.rdtype == dns.rdatatype.RRSIG), None)
            if dnskey_rrset is None or rrsig_rrset is None:
                print(
                    f"ERROR: query for a DNSKEY on {domain_name} failed on {nserver} ({nsaddr}), answer lacks DNSKEY or RRSIG: {answer}")
                errors += 1
                continue

            # the DNSKEY RRset should be self signed, validate it
            name = dns.name.from_text(domain_name)
            try:
                dns.dnssec.validate(dnskey_rrset, rrsig_rrset, {name: dnskey_rrset})
            except dns.dnssec.ValidationFailure:
                print(
                    f"WARN: DNSSEC validation failed on {domain_name} on {nserver} ({nsaddr}), answer: {answer}")
                errors += 1
            else:
                print(
                    f"INFO: DNSSEC validation succeeded on {domain_name} on {nserver} ({nsaddr})")
                success = True

    # evaluated after all nservers, not just the first one
    return success


# step3: end

def _normalize_ds_rdata(ds):
    """Return ds-rdata as "<keytag> <algorithm> <digest type> <digest>", lowercased.

    Whitespace runs between the first three fields collapse to one space and
    whitespace inside the digest is removed, so formatting differences don't
    cause a mismatch.
    """
    fields = ds.lower().split()
    return " ".join(fields[:3] + ["".join(fields[3:])])


def main(mntner):
    """Compare the registry ds-rdata of every domain of ``mntner`` with the served DNSKEYs.

    For each address of each nserver:
    - fetch the DNSKEYs (get_dnskey);
    - convert them to DS records (dnskey_to_ds);
    - report whether any registry ds-rdata value is among them.

    Mismatches increment the global ``errors``. Domains without ds-rdata and
    nservers without a known address are skipped.
    """
    global errors
    # get all domains/inet(6)nums of the mntner
    domains = get_domain_by_mntner(mntner=mntner)

    for domain_name, domain_data in domains.items():
        if domain_data["ds-rdata"] == []:
            print(f"NOTE: {domain_name} doesn't have any ds-rdata specified")
            continue

        registry_ds = {_normalize_ds_rdata(ds) for ds in domain_data["ds-rdata"]}

        for addresses in domain_data["nserver"].values():
            # no known address for this nserver -> it can't be queried
            if addresses is None:
                continue
            for ip in addresses:
                keys = get_dnskey(domain_name, ip)
                # False: the query failed (already reported and counted)
                if keys is False:
                    continue

                ds_candidates = set()
                for key in keys:
                    try:
                        key_ds = dnskey_to_ds(domain_name, key)
                    except (binascii.Error, ValueError, IndexError) as e:
                        # bad base64, non-numeric fields or too few fields
                        print(f"ERROR: trying to convert '{key}' to DS failed: {e}")
                        continue
                    ds_candidates.update(_normalize_ds_rdata(ds) for ds in key_ds)

                if registry_ds & ds_candidates:
                    print(
                        f"INFO: correct ds-rdata specified and matching DNSKEY returned by {ip} for {domain_name}")
                else:
                    print(
                        f"ERROR: invalid ds-rdata specified or no matching DNSKEY returned by {ip} for {domain_name}")
                    errors += 1


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"please specify your mntner\n   {sys.argv[0]} YOU-MNT")
        sys.exit(1)
    main(sys.argv[1])
    # Exit statuses are truncated to 8 bits on POSIX (256 errors would read as
    # success), so cap the error count at 255.
    sys.exit(min(errors, 255))


# commands to run:
# 1. drill -D <domain>.dn42 @ns1.<domain>.dn42 NS
# 2. dnskey_to_ds("<domain>.dn42"
# #<TTL> IN DNSKEY
# "257 3 13 <base64 ...>")
# 3. write dnskey to "trust-anchor"
# 4. delv @ns1.<domain>.dn42 +root=<domain>.dn42 -a ./trust-anchor.tmp SOA <domain>.dn42