#!/usr/bin/python3

import fnmatch
import optparse
import os
import shlex
import socket
import struct
import sys

os.environ.setdefault('CONDOR_CONFIG', '/etc/condor-ce/condor_config')
import htcondor2 as htcondor
from htcondorce.tools import to_str

from socket import AF_INET, AF_INET6, inet_ntop
from ctypes import (
    Structure, Union, POINTER,
    pointer, get_errno, cast,
    c_ushort, c_byte, c_void_p, c_char_p, c_uint, c_uint16, c_uint32
)
import ctypes.util
import ctypes

# When True, log() prints progress messages to stdout; parse_opts() sets it
# to False when -q/--quiet is given.
g_verbose = True

# Private IPv4 networks (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16), each
# described by the keyword arguments addr_in_network() expects.
IPV4_PRIVATE_NETWORKS = [
    {'subnet': '10.0.0.0', 'netmask': '255.0.0.0'},
    {'subnet': '172.16.0.0', 'netmask': '255.240.0.0'},
    {'subnet': '192.168.0.0', 'netmask': '255.255.0.0'},
]

# The IPv4 loopback network (127.0.0.0/8), in the same form as above.
IPV4_LOOPBACK_NETWORK = {'subnet': '127.0.0.0', 'netmask': '255.0.0.0'}

# Based on http://programmaticallyspeaking.com/getting-network-interfaces-in-python.html
# which is, in turn, based on getifaddrs.py from pydlnadms [http://code.google.com/p/pydlnadms/].

class struct_sockaddr(Structure):
    """ctypes mirror of the generic ``struct sockaddr`` from sys/socket.h
    (see socket.h in manpage section 0p).

    This is the family-agnostic socket address; a pointer to it can be cast
    to one of the family-specific ``struct sockaddr_*`` types once
    ``sa_family`` is known.

    Fields:
    - sa_family: address family of the socket (e.g. AF_INET for IPv4,
                 AF_INET6 for IPv6)
    - sa_data:   the address payload; its length and layout depend on the
                 family

    """
    _fields_ = [
        ('sa_family', c_ushort),
        ('sa_data', c_byte * 14),
    ]


class struct_sockaddr_in(Structure):
    """ctypes mirror of the leading fields of ``struct sockaddr_in`` from
    netinet/in.h (see in.h in manpage section 0p).

    Fields:
    - sin_family: address family; always AF_INET
    - sin_port:   port number
    - sin_addr:   the IPv4 address as 4 raw bytes

    """
    _fields_ = [
        ('sin_family', c_ushort),
        ('sin_port', c_uint16),
        ('sin_addr', c_byte * 4),
    ]


class struct_sockaddr_in6(Structure):
    """ctypes mirror of ``struct sockaddr_in6`` from netinet/in.h
    (see in.h in manpage section 0p).

    Fields:
    - sin6_family:   address family; always AF_INET6
    - sin6_port:     port number
    - sin6_flowinfo: IPv6 traffic class and flow information
    - sin6_addr:     the IPv6 address as 16 raw bytes
    - sin6_scope_id: set of interfaces for a scope

    """
    _fields_ = [
        ('sin6_family', c_ushort),
        ('sin6_port', c_uint16),
        ('sin6_flowinfo', c_uint32),
        ('sin6_addr', c_byte * 16),
        ('sin6_scope_id', c_uint32),
    ]


class union_ifa_ifu(Union):
    """ctypes mirror of the anonymous union behind the ``ifa_ifu`` field of
    ``struct ifaddrs`` (ifaddrs.h; see getifaddrs in manpage section 3).

    Both members are pointers to a struct sockaddr sharing the same storage.

    """
    _fields_ = [
        ('ifu_broadaddr', POINTER(struct_sockaddr)),
        ('ifu_dstaddr', POINTER(struct_sockaddr)),
    ]


class struct_ifaddrs(Structure):
    """ctypes mirror of ``struct ifaddrs`` from ifaddrs.h
    (see getifaddrs in manpage section 3).

    Each instance is one node of a singly linked list describing one network
    interface; ``ifa_next`` points at the following node.

    """


# The field list is attached after the class statement because ifa_next is a
# pointer to struct_ifaddrs itself, so the class must already exist.
struct_ifaddrs._fields_ = [
    ('ifa_next', POINTER(struct_ifaddrs)),
    ('ifa_name', c_char_p),
    ('ifa_flags', c_uint),
    ('ifa_addr', POINTER(struct_sockaddr)),
    ('ifa_netmask', POINTER(struct_sockaddr)),
    ('ifa_ifu', union_ifa_ifu),
    ('ifa_data', c_void_p),
]


# Handle to the C library, used for getifaddrs(3), freeifaddrs(3) and
# if_nametoindex(3).
# use_errno=True is required for ctypes.get_errno() to report the errno set
# by these calls; without it get_errno() always returns 0.
libc = ctypes.CDLL(ctypes.util.find_library('c'), use_errno=True)


def ifap_iter(ifap):
    """Walk the linked list headed by a pointer to a struct ifaddrs and yield
    each node in order.

    A NULL head pointer (an empty list) yields nothing instead of raising
    ValueError from dereferencing the NULL pointer.

    Params:
    - ifap: pointer(struct_ifaddrs); may be NULL
    Yields:
    - struct_ifaddrs

    """
    node = ifap
    # A ctypes pointer is falsy when it is NULL, which is how the list ends.
    while node:
        entry = node.contents
        yield entry
        node = entry.ifa_next


def getfamaddr(sa):
    """Return the address family of a struct_sockaddr together with its
    address rendered as text.

    Params:
    - sa: struct_sockaddr
    Returns: (family, addr)
    - family: the sa_family value, i.e. AF_INET, AF_INET6 or another AF_*
              constant from the socket module
    - addr:   the IPv4 address string when family is AF_INET,
              the IPv6 address string when family is AF_INET6,
              None for every other family

    """
    family = sa.sa_family
    if family == AF_INET:
        # Reinterpret the generic sockaddr as its IPv4-specific layout.
        ipv4_sa = cast(pointer(sa), POINTER(struct_sockaddr_in)).contents
        return family, inet_ntop(family, ipv4_sa.sin_addr)
    if family == AF_INET6:
        # Reinterpret the generic sockaddr as its IPv6-specific layout.
        ipv6_sa = cast(pointer(sa), POINTER(struct_sockaddr_in6)).contents
        return family, inet_ntop(family, ipv6_sa.sin6_addr)
    return family, None


class NetworkInterface(object):
    """The name, index, and IP addresses associated with a network interface.
    - name is stored exactly as given (bytes when it comes from getifaddrs).
    - index is taken from if_nametoindex(3).
    - addresses is a dict of address collections keyed by family (e.g. the
      AF_INET, AF_INET6 constants from the socket library); a family with no
      addresses is simply absent.

    """
    def __init__(self, name):
        self.name = name
        # if_nametoindex takes a C string; encode str names so that ctypes
        # does not pass them as wide-character strings.
        raw_name = name.encode() if isinstance(name, str) else name
        self.index = libc.if_nametoindex(raw_name)
        self.addresses = {}

    def __str__(self):
        # Decode bytes names so they do not render as "b'eth0'".
        if isinstance(self.name, bytes):
            name = self.name.decode("utf-8", "replace")
        else:
            name = self.name
        # Default to an empty tuple: an interface may lack addresses for a
        # family, and ",".join(None) would raise TypeError.
        # Sort so the rendering is stable regardless of set ordering.
        ipv4 = ",".join(sorted(self.addresses.get(AF_INET, ())))
        ipv6 = ",".join(sorted(self.addresses.get(AF_INET6, ())))
        return "%s [index=%d, IPv4=%s, IPv6=%s]" % (name, self.index, ipv4, ipv6)


def get_network_interfaces(pattern):
    """Return NetworkInterface objects for each network interface present on
    the machine that matches the glob in `pattern`.

    Params:
    - pattern: string containing a glob of network interfaces to match; can
               match on both interface name (e.g. eth0) or IPv4/v6 address
    Returns: list of NetworkInterface objects (empty if nothing matches)
    Raises: OSError if getifaddrs(3) fails

    """
    ifap = POINTER(struct_ifaddrs)()
    # getifaddrs takes a *(struct ifaddrs) as the argument to put the interfaces
    # into; the return code is just a status.
    if libc.getifaddrs(pointer(ifap)) != 0:
        # NOTE: ctypes only tracks errno for libraries loaded with
        # use_errno=True; otherwise get_errno() reports 0.
        err = get_errno()
        raise OSError(err, os.strerror(err))

    if not ifap:
        # Successful call but an empty list: nothing to inspect or free.
        return []

    try:
        # NetworkInterfaces keyed by interface name.  Each one has an
        # 'addresses' dict of address strings keyed by address family.
        interfaces = {}

        for ifa in ifap_iter(ifap):
            name = ifa.ifa_name
            iface = interfaces.get(name)
            if iface is None:
                iface = interfaces[name] = NetworkInterface(name)
            # getifaddrs(3) documents that ifa_addr may be a null pointer;
            # dereferencing it would raise ValueError, so record the
            # interface without an address instead.
            if not ifa.ifa_addr:
                continue
            family, addr = getfamaddr(ifa.ifa_addr.contents)
            if addr:
                iface.addresses.setdefault(family, set()).add(addr)

        # Keep only the interfaces whose name or addresses match the pattern.
        return [iface for iface in interfaces.values()
                if iface_matches(iface, pattern)]
    finally:
        libc.freeifaddrs(ifap)


def iface_matches(network_iface, pattern):
    """Tell whether a network interface is selected by the glob in `pattern`.

    Params:
    - network_iface: NetworkInterface object
    - pattern:       glob string compared against the interface name and
                     against each of its IP addresses
    Returns: True when the name matches or when any address (of any family)
    matches; False otherwise

    """
    iface_name = to_str(network_iface.name)
    if fnmatch.fnmatch(iface_name, pattern):
        return True
    # fnmatch.filter returns the matching subset; a non-empty result means
    # at least one address of that family matched.
    return any(fnmatch.filter(family_addrs, pattern)
               for family_addrs in network_iface.addresses.values())


def ipv4_to_num(ipv4_str):
    """Convert an IPv4 address string into its 32-bit unsigned integer value
    (network byte order, so the first octet is the most significant)."""
    packed = socket.inet_aton(ipv4_str)
    return int.from_bytes(packed, 'big')


def addr_in_network(addr, subnet, netmask):
    """Tell whether the IPv4 address `addr` belongs to the network described
    by `subnet` and `netmask` (all three given as dotted-quad strings)."""
    network_num, mask_num, host_num = (
        ipv4_to_num(text) for text in (subnet, netmask, addr)
    )
    # The address is inside the network when its masked bits equal the subnet.
    return network_num == (host_num & mask_num)


def is_private_addr(ipv4_addr):
    """Tell whether an IPv4 address lies in one of the private networks
    listed in IPV4_PRIVATE_NETWORKS, i.e.
    - 10.0.0.0/8
    - 172.16.0.0/12
    - 192.168.0.0/16

    """
    return any(addr_in_network(addr=ipv4_addr, **network)
               for network in IPV4_PRIVATE_NETWORKS)


def is_loopback_addr(ipv4_addr):
    """Tell whether an IPv4 address is a loopback address, i.e. lies in the
    127.0.0.0/8 network described by IPV4_LOOPBACK_NETWORK.

    """
    loopback = IPV4_LOOPBACK_NETWORK
    return addr_in_network(addr=ipv4_addr,
                           subnet=loopback['subnet'],
                           netmask=loopback['netmask'])


def pick_condor_addr():
    """Return the IPv4 address HTCondor would pick, given the addresses
    available on the host and these htcondor config params:
    - BIND_ALL_INTERFACES: when true, every interface is considered
    - NETWORK_INTERFACE: when BIND_ALL_INTERFACES is false, a glob matched
      against interface names and IP addresses to select the interfaces to
      consider; not read when BIND_ALL_INTERFACES is true

    Selection:
    - gather the IPv4 addresses of all considered interfaces
    - rank each address: loopback = 1 (lowest), private = 2, anything else = 3
    - return an arbitrary address from the highest rank that is present

    Return: the chosen ('primary') IPv4 address as a string
    Exits:  with status 1 if the considered interfaces have no IPv4 address

    """
    if htcondor.param["BIND_ALL_INTERFACES"]:
        iface_pattern = "*"
        log("HTCondor is considering all network interfaces and addresses.")
    else:
        iface_pattern = htcondor.param["NETWORK_INTERFACE"]
        log("HTCondor is considering all interfaces matching %s.", iface_pattern)

    # rank -> set of IPv4 address strings holding that rank
    ranked = {}
    for iface in get_network_interfaces(iface_pattern):
        for ipv4_addr in iface.addresses.get(AF_INET, []):
            if is_loopback_addr(ipv4_addr):
                rank = 1
            elif is_private_addr(ipv4_addr):
                rank = 2
            else:
                rank = 3
            ranked.setdefault(rank, set()).add(ipv4_addr)

    if ranked:
        return ranked[max(ranked)].pop()

    if iface_pattern == "*":
        print("No available network addresses detected.", file=sys.stderr)
    else:
        print(f"NETWORK_INTERFACE value {iface_pattern} does not match any detected addresses or interfaces.", file=sys.stderr)
    print("Host network configuration not expected to work with HTCondor-CE.", file=sys.stderr)
    sys.exit(1)

def log(msg, *args):
    """Print `msg`, %-formatted with the remaining positional arguments, to
    stdout; do nothing when g_verbose is false (quiet mode)."""
    if not g_verbose:
        return
    print(msg % args)


def parse_opts():
    """Parse the command line.

    The only option is -q/--quiet.

    Sets:
    - g_verbose to False when -q/--quiet is given; otherwise g_verbose is
      left untouched

    """
    global g_verbose
    option_parser = optparse.OptionParser()
    option_parser.add_option("-q", "--quiet", action="store_true",
                             dest="quiet", default=False)
    options, _positional = option_parser.parse_args()
    if not options.quiet:
        return
    g_verbose = False


def _exit_unsuitable(*messages):
    """Print each message to stderr, followed by the standard "not expected
    to work" line, then exit with status 1."""
    for message in messages:
        print(message, file=sys.stderr)
    print("Host network configuration not expected to work with HTCondor-CE.", file=sys.stderr)
    sys.exit(1)


def main():
    """Analyse the networking on the running host and determine if it is
    suitable for running HTCondor-CE.

    Steps, in order:
    1. Determine the hostname HTCondor-CE will use and expand it to an FQDN.
    2. If the host certificate is readable, check that its last CommonName
       matches the FQDN.
    3. Check that forward (FQDN -> IPv4) and reverse (IPv4 -> FQDN)
       resolution agree.
    4. Check that the address HTCondor would pick is consistent with that
       resolution.

    Exits: with status 1 as soon as a check fails; returns normally when the
    configuration is expected to work.

    """
    parse_opts()

    log("Starting analysis of host networking for HTCondor-CE")

    # Determine the hostname HTCondor-CE will use based on the following, in
    # descending order of preference:
    # 1. NETWORK_HOSTNAME in the HTCondor-CE config
    # 2. CONDORCE_HOSTNAME in /etc/sysconfig/condor-ce
    # 3. The system hostname
    default_hostname = socket.gethostname().lower()
    log("System hostname: %s", default_hostname)

    # os.popen runs the command through /bin/sh, so use the POSIX "." builtin
    # rather than the bash-only "source".
    fd = os.popen(". /etc/sysconfig/condor-ce && echo $CONDORCE_HOSTNAME")
    override_hostname = fd.read().strip().lower()
    # close() returns a non-None status when the command failed.
    if fd.close() or not override_hostname:
        override_hostname = default_hostname

    param_hostname = htcondor.param.get("NETWORK_HOSTNAME", default_hostname).lower()
    if param_hostname != default_hostname:
        log("System hostname (%s) overridden by NETWORK_HOSTNAME=%s in HTCondor-CE config.", default_hostname, param_hostname)
        default_hostname = param_hostname
    elif override_hostname != default_hostname:
        log("System hostname (%s) overridden by CONDORCE_HOSTNAME=%s in /etc/sysconfig/condor-ce.", default_hostname, override_hostname)
        default_hostname = override_hostname

    # Resolve the chosen hostname to its FQDN
    default_fqdn = socket.getfqdn(default_hostname).lower()
    if default_fqdn != default_hostname:
        log("Default hostname expands to FQDN %s", default_fqdn)
        log("WARNING: this could be problematic, but is likely OK.")
    else:
        log("FQDN matches hostname")

    # Attempt to validate the host certificate.
    # The path to the host certificate is taken from GSI_DAEMON_CERT in the
    # HTCondor-CE config, with "/etc/grid-security/hostcert.pem" as a default.
    #
    # Validation will fail if the last CommonName of the host cert does not
    # match the FQDN of the host, at which point, exit with an error.
    #
    # If the host cert is missing or unreadable, emit a warning but keep going.
    hostcert = htcondor.param.get("GSI_DAEMON_CERT", "/etc/grid-security/hostcert.pem")
    if os.access(hostcert, os.R_OK):
        # Quote the configured path: it is interpolated into a shell command
        # and may contain spaces or shell metacharacters.
        fd = os.popen("openssl x509 -in %s -noout -subject -nameopt compat" % shlex.quote(hostcert))
        dn = fd.read().strip()
        if fd.close() or not dn:
            print(f"WARNING: OpenSSL unable to parse host certificate {hostcert}; GSI configuration will likely fail.", file=sys.stderr)
        else:
            host_dn = dn.split("/CN=")[-1].lower()
            if host_dn != default_fqdn:
                _exit_unsuitable(f"Host certificate ({dn}) does not match HTCondor-CE hostname {default_fqdn}")
    else:
        print(f"WARNING: Unable to access certificate file at {hostcert}; skipping host certificate name check.", file=sys.stderr)

    # Forward resolution: resolve the host's FQDN on port 9618 to an IPv4 address.
    # Exit with an error if resolving fails.
    try:
        ipv4_forward_addr = socket.getaddrinfo(default_fqdn, 9618, socket.AF_INET)[0][4][0]
    except socket.gaierror as exc:
        _exit_unsuitable(f"Failed to resolve hostname {default_fqdn} to an address: {exc}")

    log("Forward resolution of hostname %s is %s.", default_fqdn, ipv4_forward_addr)

    # Reverse resolution: resolve the IPv4 address from the previous forward
    # lookup back to the host's FQDN.
    # Exit with an error if resolving fails.
    try:
        host_reverse = socket.getnameinfo((ipv4_forward_addr, 9618), socket.NI_NAMEREQD)[0].lower()
        host_reverse = socket.getfqdn(host_reverse).lower()
    except socket.gaierror as exc:
        _exit_unsuitable(f"Failed to resolve address {ipv4_forward_addr} to hostname: {exc}")

    log("Backward resolution of IPv4 %s is %s.", ipv4_forward_addr, host_reverse)

    # Exit with an error if forward and reverse resolution do not match.
    if host_reverse != default_fqdn:
        _exit_unsuitable(f"Backward resolution of address {ipv4_forward_addr} to {host_reverse} does not match default hostname {default_fqdn}.")

    log("Forward and backward resolution match!")

    # Pick the 'primary' IPv4 address according to the same algorithm HTCondor-CE
    # uses and warn the user if it is local or loopback.
    condor_addr = pick_condor_addr()
    log("HTCondor would pick address of %s as primary address.", condor_addr)

    if not htcondor.param['BIND_ALL_INTERFACES'] and (htcondor.param['NETWORK_INTERFACE'] != "*"):
        if is_private_addr(condor_addr):
            print(f"WARNING: The address {condor_addr} is a private address; using only this address may not result in a public-facing service.", file=sys.stderr)
        elif is_loopback_addr(condor_addr):
            print(f"WARNING: The address {condor_addr} is a loopback address; using only this address will not result in a public-facing service.", file=sys.stderr)

    # If the address selected by HTCondor-CE is the same as the address the
    # FQDN resolves to, we are done with network validation!
    if condor_addr == ipv4_forward_addr:
        log("HTCondor primary address %s matches system preferred address.", condor_addr)
        log("Host network configuration should work with HTCondor-CE")
        return

    # Reverse resolution 2: resolve the IPv4 address chosen by HTCondor-CE to
    # the host's FQDN.
    # Exit with an error if resolution fails.
    # Also exit with an error if forward and reverse resolution do not match.
    try:
        host_reverse = socket.getnameinfo((condor_addr, 9618), socket.NI_NAMEREQD)[0].lower()
        host_reverse = socket.getfqdn(host_reverse).lower()
    except socket.gaierror as exc:
        _exit_unsuitable(f"Failed to resolve condor address {condor_addr} to hostname: {exc}")

    if host_reverse != default_fqdn:
        mismatch = f"Backward resolution of HTCondor address {condor_addr} to {host_reverse} does not match default hostname {default_fqdn}."
        if is_private_addr(ipv4_forward_addr) or is_loopback_addr(ipv4_forward_addr):
            advice = f"The address of {ipv4_forward_addr} is not a publicly routable one; forcing use of that instead may not result in a usable service."
        else:
            advice = f"Try setting BIND_ALL_INTERFACES=false, NETWORK_INTERFACE={ipv4_forward_addr} to force the HTCondor to use only {ipv4_forward_addr}."
        _exit_unsuitable(mismatch, advice)

    log("Backward resolution of HTCondor address %s matches default hostname of %s.", condor_addr, default_fqdn)

    try:
        addresses = socket.getaddrinfo(default_fqdn, 9618, socket.AF_INET)
    except socket.gaierror as exc:
        # Shouldn't happen - we already checked for this
        _exit_unsuitable(f"Failed to resolve hostname {default_fqdn} to an address: {exc}")

    # Verify that one of the addresses the FQDN resolves to is the address
    # HTCondor-CE has picked to bind to.
    # Exit with an error if it is not the case.
    # Each getaddrinfo entry is (family, type, proto, canonname, sockaddr);
    # sockaddr[0] is the address string.
    if not any(entry[4][0] == condor_addr for entry in addresses):
        _exit_unsuitable(f"HTCondor primary address {condor_addr} is not a valid address for default FQDN of {default_fqdn}.")

    log("Forward resolution of default hostname (%s) includes default HTCondor address %s.", default_fqdn, condor_addr)

    log("Host network configuration should work with HTCondor-CE")

# Run the network analysis only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    main()
