#!/usr/bin/python3
"""
Spins up two nodes with the old binary version and waits until they produce some blocks.
Shuts down one of the nodes and restarts it with the same data folder using the new binary.
Makes sure that the node can still produce blocks, including after a second restart.
"""

import logging
import sys
import pathlib

sys.path.append(str(pathlib.Path(__file__).resolve().parents[2] / 'lib'))

import branches
import cluster
from transaction import sign_deploy_contract_tx, sign_function_call_tx, sign_staking_tx
import utils

logging.basicConfig(level=logging.INFO)

NUM_SHARDS = 4
EPOCH_LENGTH = 5
# GC window (in epochs) used while the node is running the old binary.
LARGE_GC_NUM_EPOCHS_TO_KEEP = 15
# GC window (in epochs) used once the node is running the new binary.
DEFAULT_GC_NUM_EPOCHS_TO_KEEP = 5

# Per-node client config overrides: track every shard and use a large GC window.
#
# The GC window must be large while the node is running the old binary (2.0):
# congestion info is bootstrapped from the genesis block, so the genesis block
# must not be garbage collected. The new binary (2.1) persists congestion info
# on disk and is therefore resilient to the genesis block being GC'd. The old
# binary has no such mechanism, so GC must not kick in before the node is
# switched to the new binary. After the restart with the new binary, the usual
# GC window is restored (see `main`).
# TODO(#11902): Remove this GC override once the old binary in this test is 2.1.
node_config = dict(
    tracked_shards=list(range(NUM_SHARDS)),
    gc_num_epochs_to_keep=LARGE_GC_NUM_EPOCHS_TO_KEEP,
)


def deploy_contract(node):
    """Deploy the test contract with a transaction signed by `node.signer_key`.

    The deploy transaction uses nonce 10 and is anchored to the latest block.
    After waiting for the transaction (timeout=15), three more blocks are
    awaited so that the deployment has settled before the caller continues.
    """
    latest_hash = node.get_latest_block().hash_bytes
    contract_code = utils.load_test_contract()
    deploy_tx = sign_deploy_contract_tx(node.signer_key, contract_code, 10,
                                        latest_hash)
    node.send_tx_and_wait(deploy_tx, timeout=15)
    utils.wait_for_blocks(node, count=3)


def send_some_tx(node):
    """Call `write_key_value` on the signer's own account 10 times.

    Nonces start 10 above the signer's current access-key nonce and grow by 10
    per transaction; each transaction is anchored to the latest block at the
    time it is signed. Every call must finish without an RPC error or a
    `Failure` status. Afterwards, three more blocks are awaited.
    """
    signer = node.signer_key
    nonce = node.get_nonce_for_pk(signer.account_id, signer.pk) + 10
    for _ in range(10):
        block_hash = node.get_latest_block().hash_bytes
        # 16-byte argument: byte 0 and byte 8 are derived from the nonce so
        # that every write uses distinct data; all other bytes stay zero.
        seed = nonce // 10
        payload = bytearray(16)
        payload[0] = seed % 256
        payload[8] = seed % 255
        call_tx = sign_function_call_tx(signer, signer.account_id,
                                        'write_key_value', bytes(payload),
                                        10000000000000, 100000000000, nonce,
                                        block_hash)
        nonce += 10
        outcome = node.send_tx_and_wait(call_tx, timeout=15)
        assert 'error' not in outcome, outcome
        assert 'Failure' not in outcome['result']['status'], outcome
    utils.wait_for_blocks(node, count=3)


# Unstake and restake the validator running `node` to ensure that some
# validator kickout is recorded in the DB.
# Reproduces issue #11569.
def unstake_and_stake(node, tx_sender_node):
    """Unstake `node`'s validator, wait, then restake half of its balance.

    Args:
        node: validator whose stake is withdrawn and re-added; its
            `signer_key` and `validator_key` are used to sign both staking
            transactions.
        tx_sender_node: node used to query account state and to submit the
            transactions.

    After each staking transaction, `EPOCH_LENGTH * 2` blocks are awaited so
    that at least one epoch boundary passes before the next step.
    """
    account_id = node.signer_key.account_id
    account = tx_sender_node.get_account(account_id)['result']
    full_balance = int(account['amount']) + int(account['locked'])

    logging.info(f'Unstaking {account_id}...')
    nonce = tx_sender_node.get_nonce_for_pk(account_id, node.signer_key.pk) + 10
    _send_staking_tx(node, tx_sender_node, 0, nonce)
    utils.wait_for_blocks(tx_sender_node, count=EPOCH_LENGTH * 2)

    logging.info(f'Restaking {account_id}...')
    _send_staking_tx(node, tx_sender_node, full_balance // 2, nonce + 10)
    utils.wait_for_blocks(tx_sender_node, count=EPOCH_LENGTH * 2)


def _send_staking_tx(node, tx_sender_node, amount, nonce):
    """Sign a staking tx of `amount` for `node` and submit it via `tx_sender_node`.

    Asserts that the transaction finished without an RPC error or a `Failure`
    status.
    """
    # Fetch a fresh block hash for every transaction. A hash captured before
    # waiting several epochs can fall outside the transaction validity period,
    # and the transaction would then be rejected as expired.
    block_hash = tx_sender_node.get_latest_block().hash_bytes
    tx = sign_staking_tx(node.signer_key, node.validator_key, amount, nonce,
                         block_hash)
    res = tx_sender_node.send_tx_and_wait(tx, timeout=15)
    assert 'error' not in res, res
    assert 'Failure' not in res['result']['status'], res


def main():
    """Run the stable -> current binary DB migration scenario.

    1. Start a two-validator cluster on the stable binary, with all shards
       tracked.
    2. Deploy the test contract, write some state, and unstake/restake the
       second validator.
    3. Check that the chain is still below the large GC horizon. Then restart
       node 0 on the current binary with the default GC window, and check
       that it keeps producing blocks and accepting transactions.
    4. Restart node 0 once more on the current binary and check that it still
       produces blocks.
    """
    executables = branches.prepare_ab_test()
    node_root = utils.get_near_tempdir('db_migration', clean=True)

    logging.info(f"The near root is {executables.stable.root}...")
    logging.info(f"The node root is {node_root}...")

    stable_config = executables.stable.node_config()
    genesis_overrides = [
        ['epoch_length', EPOCH_LENGTH],
        ["block_producer_kickout_threshold", 0],
        ["chunk_producer_kickout_threshold", 0],
    ]
    # Make sure nodes track all shards to:
    # 1. Avoid state sync after restaking
    # 2. Respond to all view queries
    client_overrides = {0: node_config, 1: node_config}

    logging.info("Starting stable nodes...")
    nodes = cluster.start_cluster(2, 0, NUM_SHARDS, stable_config,
                                  genesis_overrides, client_overrides)
    migrated = nodes[0]

    logging.info("Running the stable node...")
    utils.wait_for_blocks(migrated, count=EPOCH_LENGTH)
    logging.info("Blocks are being produced, sending some tx...")
    deploy_contract(migrated)
    send_some_tx(migrated)
    unstake_and_stake(nodes[1], migrated)

    # The old binary relies on the genesis block not being GC'd (see the
    # comment on `node_config`), so the stable phase must end before the large
    # GC window could start collecting it.
    height, _ = utils.wait_for_blocks(migrated, count=1)
    gc_horizon = EPOCH_LENGTH * LARGE_GC_NUM_EPOCHS_TO_KEEP
    assert height < gc_horizon, "Node should run without GC"

    migrated.kill()
    logging.info(
        "Stable node has produced blocks... Stopping the stable node... ")

    # Switch the same data folder over to the current binary and verify that it
    # keeps running for a few more blocks.
    logging.info("Starting the current node...")
    current = executables.current
    migrated.near_root = current.root
    migrated.binary_name = current.neard
    migrated.change_config(
        {"gc_num_epochs_to_keep": DEFAULT_GC_NUM_EPOCHS_TO_KEEP})
    migrated.start(boot_node=migrated)

    logging.info("Running the current node...")
    utils.wait_for_blocks(migrated, count=EPOCH_LENGTH * 4)
    logging.info("Blocks are being produced, sending some tx...")
    send_some_tx(migrated)

    logging.info(
        "Current node has produced blocks... Stopping the current node... ")
    migrated.kill()

    # A second restart on the current binary checks that the DB it wrote
    # itself can also be reopened.
    logging.info("Restarting the current node...")
    migrated.start(boot_node=migrated)
    utils.wait_for_blocks(migrated, count=EPOCH_LENGTH * 4)


if __name__ == '__main__':
    # Run the scenario only when executed as a script, not when imported.
    main()
