#!/usr/bin/env python3
#
# Copyright 2019 Orange
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# 	You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# 	See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import sys
import json
import re
from subprocess import check_output, STDOUT, CalledProcessError
from os.path import basename
from random import shuffle
from itertools import groupby

# Operation statuses swapped by a pause, as [old status, new status];
# unpause uses the same pair in reverse order
PAUSE_LABELS_OP = ["ToDo", "Paused"]
# Message passed to die() when a pod lookup returns nothing
NO_PODS_FOUND = "No pods found for operation"
# Label selector appended by grep_pods() when the caller supplied no
# option mentioning operation-status
NO_ONGOING_OP = "--selector=operation-status notin (Ongoing, Finalizing)"
# Two runs of word characters separated by a single non-word character;
# used by Command.restart to split a --rack value into dc name and rack name
RE_RACK_NAME = re.compile(r"(\w+)\W(\w+)")

def k_apply_with_input(input, error, *args):
    """Apply a JSON resource through ``kubectl apply -f -``.

    The ``metadata.resourceVersion`` field is removed from the document
    before it is piped to kubectl on standard input.

    Args:
        input: JSON document (``str``) describing the resource to apply.
        error: Message passed to ``die()`` when kubectl exits with a
            non-zero status.
        *args: Accepted for backward compatibility; not used.
    """
    resource = json.loads(input)
    # Tolerate documents without metadata / resourceVersion instead of
    # failing with a KeyError.
    # NOTE(review): presumably dropped so a stale version does not make the
    # apply conflict -- confirm.
    metadata = resource.get("metadata") or {}
    metadata.pop("resourceVersion", None)
    payload = json.dumps(resource).encode()
    try:
        check_output(["kubectl", "apply", "-f", "-"], stderr=STDOUT, input=payload)
    except CalledProcessError:
        die(error)

def k(*args):
    """Run ``kubectl`` with the given arguments and return its output.

    Standard error is merged into standard output.

    Args:
        *args: Command-line arguments (strings) passed to kubectl.

    Returns:
        * The decoded output as a ``str`` when an output-format option
          (``-o...``, ``--output`` or ``--output=...``) is among the
          arguments.
        * Otherwise a list of rows, each row being the whitespace-split
          fields of one output line; the first line (header) and the
          trailing element after the final newline are skipped.
        * An empty list when kubectl exits with a non-zero status.
    """
    params = list(args)
    try:
        out = check_output(["kubectl"] + params, stderr=STDOUT)
    except CalledProcessError:
        return []
    result = out.decode("utf-8")
    # Inspect each argument on its own: a substring search over the joined
    # command line would also match resource names or flags that merely
    # contain "-o" (e.g. "my-operator-0", "--overwrite").
    raw_output = any(
        p.startswith("-o") or p == "--output" or p.startswith("--output=")
        for p in params
    )
    if raw_output:
        return result
    return [line.split() for line in result.split("\n")[1:-1]]

def die(msg):
    """Print `msg` followed by a stop notice, then exit with status 1."""
    notice = f"{msg}. Have to stop here ..."
    print(notice)
    raise SystemExit(1)

def pod_is_mandatory(name):
    """Return `name` unchanged, calling die() first when it is empty or None."""
    if name:
        return name
    die(f"Pod {name} not found")
    return name

def get_pods(args, option=None):
    """Return the names of the pods targeted by the parsed arguments.

    Args:
        args: Parsed arguments exposing ``prefix`` and ``pod`` attributes.
            When ``prefix`` is set, pods whose name contains it are
            returned; otherwise ``pod`` is mandatory and the lookup is
            restricted to that pod name with a field selector.
        option: Optional list of extra ``kubectl get pods`` options
            (each of the form ``flag=value``).

    Returns:
        A non-empty list of pod names; the program exits through ``die()``
        when no pod matches.
    """
    # A fresh list per call (no shared mutable default); always hand a real
    # list to grep_pods().
    option = [] if option is None else list(option)
    if args.prefix:
        pods = grep_pods(args.prefix, option)
    else:
        pods = grep_pods(pod_is_mandatory(args.pod),
                         option + [f"--field-selector=metadata.name={args.pod}"])
    if not pods:
        die(NO_PODS_FOUND)
    return pods

def grep_pods(prefix, option=None):
    """Return the names of running pods whose name contains `prefix`.

    Args:
        prefix: Substring that must appear in the pod name.
        option: Optional list of extra ``kubectl get pods`` options. Each
            option must have the form ``flag=value`` (a missing ``=``
            raises ``ValueError``). When none of them mentions
            ``operation-status``, the NO_ONGOING_OP selector is added.

    Returns:
        A non-empty list of pod names; the program exits through ``die()``
        when no pod matches.
    """
    option = [] if option is None else list(option)
    # We only look for running pods
    selectors = option + ["--field-selector=status.phase=Running"]
    # Pods having an operation with the provided status or with no ongoing operation
    if not any('operation-status' in o for o in option):
        selectors.append(NO_ONGOING_OP)
    # Options sharing the same flag (text up to and including '=') are
    # merged into a single option with comma-separated values.
    merged = []
    for flag, group in groupby(sorted(selectors), lambda o: o[:o.index('=') + 1]):
        merged.append(flag + ','.join(o[len(flag):] for o in group))
    # k() may return nothing at all, and rows may be empty: guard both.
    rows = k("get", "pods", *merged) or []
    pods = [row[0] for row in rows if row and prefix in row[0]]
    if not pods:
        die(NO_PODS_FOUND)
    return pods

def get_namespace():
    """Return the namespace reported by ``kubectl config view --minify``.

    Falls back to ``'default'`` when kubectl reports nothing.
    """
    namespace = k('config', 'view', '--minify', '--output', 'jsonpath={..namespace}')
    if namespace:
        return namespace
    return 'default'

def set_pod_label(pod, operation, status="ToDo", argument=None):
    """Label `pod` with an operation name, its status and an optional argument.

    Existing labels are overwritten (``kubectl label ... --overwrite``).
    """
    print(f"Set status of operation {operation} on pod {pod} to {status}")
    command = [
        "label", "pods", pod,
        f"operation-name={operation}",
        f"operation-status={status}",
    ]
    # The argument label is only set when a non-empty argument is given
    if argument:
        command += [f"operation-argument={argument}"]
    command.append("--overwrite")
    k(*command)

def available_pod_in_crd(crd):
    """Return one running pod of `crd` that is not part of a rack's last operation.

    Candidate pods are the running pods whose name contains `crd`. Pods
    listed under ``podLastOperation.pods`` of any entry of the resource's
    ``status.cassandraRackStatus`` are excluded.

    Args:
        crd: Name of the CassandraCluster resource.

    Returns:
        The name of one randomly chosen eligible pod. The program exits
        through ``die()`` when the resource cannot be read or when every
        candidate pod is excluded.
    """
    pods = grep_pods(crd)
    crd_content = k("get", "cassandracluster", crd, "-o", "json")
    if not crd_content:
        die(f"crd {crd} not found")
    rack_status = json.loads(crd_content)["status"]["cassandraRackStatus"]
    busy_pods = set()
    for status in rack_status.values():
        if not isinstance(status, dict):
            continue
        pod_last_op = status.get("podLastOperation")
        if pod_last_op and "pods" in pod_last_op:
            busy_pods.update(pod_last_op["pods"] or ())
    # Let's not keep the default ascii sort: shuffle the candidate list
    # (shuffling the set of busy pods would raise a TypeError).
    shuffle(pods)
    candidate = next((p for p in pods if p not in busy_pods), None)
    if candidate is None:
        die(f"No available pod found in crd {crd}")
    return candidate


class Command(object):
    """Command-line dispatcher of the plugin.

    Instantiating the class parses ``sys.argv``: the first argument selects
    the sub-command, which is run by the public method of the same name.
    Each sub-command parses the remaining arguments itself.
    """

    def __init__(self):
        """Parse the sub-command from ``sys.argv[1]`` and run it."""
        plugin = basename(sys.argv[0])
        parser = argparse.ArgumentParser(
            description='Kubernetes plugin used to trigger operations',
            usage=f"""{plugin} <command> [<args>]

The available commands are:
   cleanup
   upgradesstables
   rebuild
   remove
   restart
   pause
   unpause

For more information you can run {plugin} <command> --help
""")
        parser.add_argument('command', help='Subcommand to run')
        # Only the sub-command is parsed here; the rest is left to the sub-command
        args = parser.parse_args(sys.argv[1:2])
        self._command_must_exist(args.command, parser)
        print(f"Namespace {get_namespace()}")

        # Call function corresponding to the command passed
        getattr(self, args.command)()

    def _command_must_exist(self, operation, parser):
        """Exit with status 1 unless `operation` names a public command method.

        Names starting with an underscore are rejected so that private
        helpers and dunder methods cannot be invoked from the command line.
        """
        handler = None if operation.startswith("_") else getattr(self, operation, None)
        if not callable(handler):
            print(f"Unrecognized operation {operation}")
            parser.print_help()
            sys.exit(1)

    def _parse_operation_options(self, operation, argv_index=2):
        """Parse the mutually exclusive ``--pod`` / ``--prefix`` options.

        Args:
            operation: Name shown as program name in the usage message.
            argv_index: Index in ``sys.argv`` where the options start.

        Returns:
            The parsed ``argparse.Namespace`` (attributes ``pod``, ``prefix``).
        """
        parser = argparse.ArgumentParser(operation)
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument('--pod')
        group.add_argument('--prefix')
        return parser.parse_args(sys.argv[argv_index:])

    def _simple_operation(self, operation):
        """Label every targeted pod with `operation` (default status)."""
        args = self._parse_operation_options(operation)
        pods = get_pods(args)
        for pod in pods:
            set_pod_label(pod, operation)

    def _pause_operation(self, pause):
        """Pause (`pause` true) or unpause an operation on the targeted pods.

        The operation name is read from ``sys.argv[2]`` and the pod
        options from ``sys.argv[3:]``. Only pods currently labelled with
        that operation in the "old" status are switched to the "new" one.
        """
        verb = "pause" if pause else "unpause"
        # We expect a supported operation to be passed
        parser = argparse.ArgumentParser(verb)
        parser.add_argument('operation', help=f'Operation to {verb}')
        args = parser.parse_args(sys.argv[2:3])
        self._command_must_exist(args.operation, parser)
        operation = args.operation
        args = self._parse_operation_options(operation, argv_index=3)
        # PAUSE_LABELS_OP is [old, new] for a pause; unpause goes the other way
        old_status, new_status = PAUSE_LABELS_OP if pause else reversed(PAUSE_LABELS_OP)
        pods = get_pods(args, [f"--selector=operation-name={operation},operation-status={old_status}"])
        for pod in pods:
            set_pod_label(pod, operation, status=new_status)

    def pause(self):
        """Sub-command: pause an operation."""
        self._pause_operation(True)

    def unpause(self):
        """Sub-command: unpause an operation."""
        self._pause_operation(False)

    def cleanup(self):
        """Sub-command: label the targeted pods with the cleanup operation."""
        self._simple_operation(self.cleanup.__name__)

    def upgradesstables(self):
        """Sub-command: label the targeted pods with the upgradesstables operation."""
        self._simple_operation(self.upgradesstables.__name__)

    def _rolling_restart(self, topology, matching_test):
        """Set ``rollingRestart`` on every rack of `topology` accepted by `matching_test`.

        Args:
            topology: List of dc dicts, each with a ``name`` and a ``rack``
                list of rack dicts (each with a ``name``). Matching rack
                dicts are modified in place.
            matching_test: Callable ``(dc_name, rack_name) -> bool``.

        Exits through ``die()`` when no rack matches.
        """
        matching = {f"{dc['name']}.{rack['name']}": rack
                    for dc in topology for rack in dc.get('rack', [])
                    if matching_test(dc['name'], rack['name'])}
        if not matching:
            die("Can't match dc or rack")
        for full_name, rack in matching.items():
            print(f"Trigger {self.restart.__name__} of {full_name}")
            rack['rollingRestart'] = True

    def restart(self):
        """Sub-command: flag racks of a CassandraCluster for a rolling restart.

        ``--crd`` names the resource; exactly one of ``--rack`` (values
        matching RE_RACK_NAME, i.e. dc and rack names separated by one
        non-word character), ``--dc`` or ``--full`` selects the racks.
        The modified resource is then applied with kubectl.
        """
        parser = argparse.ArgumentParser(self.restart.__name__)
        parser.add_argument('--crd', required=True)
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument('--rack', nargs='+')
        group.add_argument('--dc', nargs='+')
        group.add_argument('--full', action='store_true')
        args = parser.parse_args(sys.argv[2:])
        crd_content = k("get", "cassandracluster", args.crd, "-o", "json")
        if not crd_content:
            die(f"crd {args.crd} not found")
        crd_content = json.loads(crd_content)
        topology = crd_content["spec"]["topology"]["dc"]

        if args.rack:
            for full_name in args.rack:
                m = RE_RACK_NAME.match(full_name)
                if not m:
                    die(f"Can't extract dcname and rackname from {full_name}")
                dc_name, rack_name = m.groups()
                self._rolling_restart(
                    topology, lambda d, r: d == dc_name and r == rack_name)
        elif args.dc:
            for dc_name in args.dc:
                self._rolling_restart(topology, lambda d, r: d == dc_name)
        elif args.full:
            self._rolling_restart(topology, lambda d, r: True)

        k_apply_with_input(json.dumps(crd_content), f"Can't update CassandraCluster {args.crd}")

    def remove(self):
        """Sub-command: trigger the remove operation for a pod.

        The pod given by ``--from-pod`` (or, with ``--crd``, a pod picked
        by ``available_pod_in_crd``) is labelled with the remove operation
        and the argument ``<pod>_<previous-ip>``.
        """
        parser = argparse.ArgumentParser(self.remove.__name__)
        parser.add_argument('--pod')
        parser.add_argument('--previous-ip', default="")
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument('--from-pod')
        group.add_argument('--crd')
        args = parser.parse_args(sys.argv[2:])
        pod = pod_is_mandatory(args.pod)
        from_pod = args.from_pod

        # NOTE(review): unreachable as written -- pod_is_mandatory() above
        # already exits when --pod is missing. Kept until the intended
        # contract (--pod mandatory vs. either option) is confirmed.
        if not (args.pod or args.previous_ip):
            die("At least one option must be used between --pod and --previous-ip")

        if not from_pod:
            from_pod = available_pod_in_crd(args.crd)

        print(f"Trigger {self.remove.__name__} of pod {pod} from pod {from_pod}")
        set_pod_label(from_pod, self.remove.__name__, argument=f"{args.pod}_{args.previous_ip}")

    def rebuild(self):
        """Sub-command: label the targeted pods with the rebuild operation.

        The positional ``from-dc`` value is passed as the operation argument.
        """
        parser = argparse.ArgumentParser(self.rebuild.__name__)
        group = parser.add_mutually_exclusive_group(required=True)
        group.add_argument('--pod')
        group.add_argument('--prefix')
        parser.add_argument('from_dc', metavar='from-dc')
        args = parser.parse_args(sys.argv[2:])
        pods = get_pods(args)
        for pod in pods:
            set_pod_label(pod, self.rebuild.__name__, argument=args.from_dc)

    def replace(self):
        """Sub-command: prepare the replacement of a pod.

        Stores a ``pre_run.sh`` script in the ConfigMap of the
        CassandraCluster whose name is contained in the pod name; the
        script appends ``-Dcassandra.replace_address_first_boot=<previous-ip>``
        to ``/etc/cassandra/jvm.options`` when run on the host named after
        the pod. The PVC ``data-<pod>`` is then deleted.
        """
        parser = argparse.ArgumentParser(self.replace.__name__)
        parser.add_argument('--pod', required=True)
        parser.add_argument('--previous-ip', required=True)
        args = parser.parse_args(sys.argv[2:])
        pod_is_mandatory(args.pod)
        pre_run = f"test \"$(hostname)\" == '{args.pod}' && echo -Dcassandra.replace_address_first_boot={args.previous_ip} >> /etc/cassandra/jvm.options"
        crds = k("get", "cassandracluster")
        if not crds:
            die("No crds found")
        # The cluster is identified by its name being part of the pod name
        crd = next((c[0] for c in crds if c and c[0] in args.pod), None)
        if not crd:
            die(f"No crd found for pod {args.pod}")
        config_map = k("get", "cassandracluster", crd, "--output=jsonpath={.spec.configMapName}")
        if not config_map:
            die("No ConfigMap found")
        config_map_content = k("get", "configmap", config_map, "-o", "json")
        if not config_map_content:
            die(f"ConfigMap {config_map} not found")
        result = json.loads(config_map_content)
        result.setdefault("data", {})["pre_run.sh"] = pre_run
        print(f"Update pre_run.sh in ConfigMap {config_map}")
        k_apply_with_input(json.dumps(result), f"Can't update ConfigMap {config_map}")
        print(f"Delete pvc data-{args.pod}")
        k("delete", "pvc", f"data-{args.pod}")


# Script entry point: instantiating Command parses sys.argv and runs the
# requested sub-command
if __name__ == '__main__':
    Command()
