package net.corda.client.rpc.internal

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.Serializer
import com.esotericsoftware.kryo.io.Input
import com.esotericsoftware.kryo.io.Output
import com.google.common.cache.Cache
import com.google.common.cache.CacheBuilder
import com.google.common.cache.RemovalCause
import com.google.common.cache.RemovalListener
import com.google.common.util.concurrent.SettableFuture
import com.google.common.util.concurrent.ThreadFactoryBuilder
import net.corda.client.rpc.RPCException
import net.corda.client.rpc.RPCSinceVersion
import net.corda.core.context.Actor
import net.corda.core.context.Trace
import net.corda.core.context.Trace.InvocationId
import net.corda.core.internal.LazyStickyPool
import net.corda.core.internal.LifeCycle
import net.corda.core.internal.ThreadBox
import net.corda.core.messaging.RPCOps
import net.corda.core.serialization.SerializationContext
import net.corda.core.serialization.serialize
import net.corda.core.utilities.Try
import net.corda.core.utilities.contextLogger
import net.corda.core.utilities.debug
import net.corda.core.utilities.getOrThrow
import net.corda.nodeapi.RPCApi
import net.corda.nodeapi.internal.DeduplicationChecker
import org.apache.activemq.artemis.api.core.RoutingType
import org.apache.activemq.artemis.api.core.SimpleString
import org.apache.activemq.artemis.api.core.client.*
import org.apache.activemq.artemis.api.core.client.ActiveMQClient.DEFAULT_ACK_BATCH_SIZE
import rx.Notification
import rx.Observable
import rx.subjects.UnicastSubject
import java.lang.reflect.InvocationHandler
import java.lang.reflect.Method
import java.time.Instant
import java.util.*
import java.util.concurrent.*
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicLong
import kotlin.reflect.jvm.javaMethod

/**
 * This class provides a proxy implementation of an RPC interface for RPC clients. It translates API calls to lower-level
 * RPC protocol messages. For this protocol see [RPCApi].
 *
 * When a method is called on the interface the arguments are serialised and the request is forwarded to the server. The
 * server then executes the code that implements the RPC and sends a reply.
 *
 * An RPC reply may contain [Observable]s, which are serialised simply as unique IDs. On the client side we create a
 * [UnicastSubject] for each such ID. Subsequently the server may send observations attached to this ID, which are
 * forwarded to the [UnicastSubject]. Note that the observations themselves may contain further [Observable]s, which are
 * handled in the same way.
 *
 * To do the above we take advantage of Kryo's data structure traversal. When the client is deserialising a message from
 * the server that may contain Observables it is supplied with an [ObservableContext] that exposes the map used to demux
 * the observations. When an [Observable] is encountered during traversal a new [UnicastSubject] is added to the map and
 * we carry on. Each observation later contains the corresponding Observable ID, and we just forward that to the
 * associated [UnicastSubject].
 *
 * The client may signal that it no longer consumes a particular [Observable]. This may be done explicitly by
 * unsubscribing from the [Observable], or if the [Observable] is garbage collected the client will eventually
 * automatically signal the server. This is done using a cache that holds weak references to the [UnicastSubject]s.
 * The cleanup happens in batches using a dedicated reaper, scheduled on [reaperExecutor].
 *
 * Closing the handler (see [forceClose] and [notifyServerAndClose]) fails every RPC that is still waiting for its reply
 * with an [RPCException], so that calling threads are not left blocked forever.
 */
class RPCClientProxyHandler(
        private val rpcConfiguration: RPCClientConfiguration,
        private val rpcUsername: String,
        private val rpcPassword: String,
        private val serverLocator: ServerLocator,
        private val clientAddress: SimpleString,
        private val rpcOpsClass: Class<out RPCOps>,
        serializationContext: SerializationContext,
        private val sessionId: Trace.SessionId,
        private val externalTrace: Trace?,
        private val impersonatedActor: Actor?
) : InvocationHandler {

    /**
     * Lifecycle of the handler. [start] moves it from [UNSTARTED] to [SERVER_VERSION_NOT_SET],
     * [setServerProtocolVersion] then moves it to [STARTED], and closing moves it to [FINISHED].
     */
    private enum class State {
        UNSTARTED,
        SERVER_VERSION_NOT_SET,
        STARTED,
        FINISHED
    }

    private val lifeCycle = LifeCycle(State.UNSTARTED)

    private companion object {
        private val log = contextLogger()
        // To check whether toString() is being invoked
        val toStringMethod: Method = Object::toString.javaMethod!!

        /**
         * Appends [callSite] to the end of [throwable]'s cause chain, so that errors surfaced to the caller also show
         * where the originating RPC was made.
         *
         * This is best-effort and never throws. It runs inside the Artemis message handler before the reply future is
         * completed or the observation is delivered, so an exception here would leave the caller waiting. If the last
         * throwable in the chain cannot accept a new cause ([Throwable.initCause] throws [IllegalStateException] when
         * the cause was explicitly initialised, even to null), the call site is simply not attached. The walk also
         * stops if the chain loops back on itself or already contains [callSite].
         */
        private fun addRpcCallSiteToThrowable(throwable: Throwable, callSite: Throwable) {
            // Identity-based, so that a cyclic cause chain (A -> B -> A) cannot make this loop forever.
            val visited = Collections.newSetFromMap(IdentityHashMap<Throwable, Boolean>())
            var currentThrowable = throwable
            while (visited.add(currentThrowable)) {
                if (currentThrowable === callSite) {
                    // Already attached; calling initCause(callSite) on callSite itself would be self-causation.
                    return
                }
                val cause = currentThrowable.cause
                if (cause == null) {
                    try {
                        currentThrowable.initCause(callSite)
                    } catch (e: IllegalStateException) {
                        log.debug { "Could not attach RPC call site to $throwable: ${e.message}" }
                    }
                    return
                }
                currentThrowable = cause
            }
        }
    }

    // Used for reaping
    private var reaperExecutor: ScheduledExecutorService? = null
    // Used for sending
    private var sendExecutor: ExecutorService? = null

    // A sticky pool for running Observable.onNext()s. We need the stickiness to preserve the observation ordering.
    private val observationExecutorThreadFactory = ThreadFactoryBuilder().setNameFormat("rpc-client-observation-pool-%d").setDaemon(true).build()
    private val observationExecutorPool = LazyStickyPool(rpcConfiguration.observationExecutorPoolSize) {
        Executors.newFixedThreadPool(1, observationExecutorThreadFactory)
    }

    // Holds the RPC reply futures, keyed by request ID, until the reply arrives or the handler is closed.
    private val rpcReplyMap = RpcReplyMap()
    // Optionally holds RPC call site stack traces to be shown on errors/warnings.
    private val callSiteMap = if (rpcConfiguration.trackRpcCallSites) CallSiteMap() else null
    // Holds the Observables and a reference store to keep Observables alive when subscribed to.
    private val observableContext = ObservableContext(
            callSiteMap = callSiteMap,
            observableMap = createRpcObservableMap(),
            hardReferenceStore = Collections.synchronizedSet(mutableSetOf<Observable<*>>())
    )
    // Holds a reference to the scheduled reaper.
    private var reaperScheduledFuture: ScheduledFuture<*>? = null
    // The protocol version of the server, to be initialised to the value of [RPCOps.protocolVersion]
    private var serverProtocolVersion: Int? = null

    // Stores the Observable IDs that are already removed from the map but are not yet sent to the server.
    private val observablesToReap = ThreadBox(object {
        var observables = ArrayList<InvocationId>()
    })
    private val serializationContextWithObservableContext = RpcClientObservableSerializer.createContext(serializationContext, observableContext)

    /**
     * Builds the Observable ID -> [UnicastSubject] cache. Values are held weakly. Whenever an entry is removed, for any
     * cause, its call site is dropped and its ID is queued in [observablesToReap] for the next reap. Entries that were
     * garbage collected additionally trigger a warning.
     */
    private fun createRpcObservableMap(): RpcObservableMap {
        val onObservableRemove = RemovalListener<InvocationId, UnicastSubject<Notification<*>>> {
            val observableId = it.key!!
            val rpcCallSite = callSiteMap?.remove(observableId)
            if (it.cause == RemovalCause.COLLECTED) {
                log.warn(listOf(
                        "A hot observable returned from an RPC was never subscribed to.",
                        "This wastes server-side resources because it was queueing observations for retrieval.",
                        "It is being closed now, but please adjust your code to call .notUsed() on the observable",
                        "to close it explicitly. (Java users: subscribe to it then unsubscribe). This warning",
                        "will appear less frequently in future versions of the platform and you can ignore it",
                        "if you want to.").joinToString(" "), rpcCallSite)
            }
            observablesToReap.locked { observables.add(observableId) }
        }
        return CacheBuilder.newBuilder().
                weakValues().
                removalListener(onObservableRemove).
                concurrencyLevel(rpcConfiguration.cacheConcurrencyLevel).
                build()
    }

    private var sessionFactory: ClientSessionFactory? = null
    private var producerSession: ClientSession? = null
    private var consumerSession: ClientSession? = null
    private var rpcProducer: ClientProducer? = null
    private var rpcConsumer: ClientConsumer? = null

    private val deduplicationChecker = DeduplicationChecker(rpcConfiguration.deduplicationCacheExpiry)
    private val deduplicationSequenceNumber = AtomicLong(0)

    /**
     * Start the client. This creates the per-client temporary queue, the producer and consumer sessions, and schedules
     * the reaper every [RPCClientConfiguration.reapInterval]. Afterwards the handler is in the SERVER_VERSION_NOT_SET
     * state until [setServerProtocolVersion] is called.
     */
    fun start() {
        lifeCycle.requireState(State.UNSTARTED)
        reaperExecutor = Executors.newScheduledThreadPool(
                1,
                ThreadFactoryBuilder().setNameFormat("rpc-client-reaper-%d").setDaemon(true).build()
        )
        sendExecutor = Executors.newSingleThreadExecutor(
                ThreadFactoryBuilder().setNameFormat("rpc-client-sender-%d").build()
        )
        reaperScheduledFuture = reaperExecutor!!.scheduleAtFixedRate(
                this::reapObservablesAndNotify,
                rpcConfiguration.reapInterval.toMillis(),
                rpcConfiguration.reapInterval.toMillis(),
                TimeUnit.MILLISECONDS
        )
        sessionFactory = serverLocator.createSessionFactory()
        producerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
        rpcProducer = producerSession!!.createProducer(RPCApi.RPC_SERVER_QUEUE_NAME)
        consumerSession = sessionFactory!!.createSession(rpcUsername, rpcPassword, false, true, true, false, DEFAULT_ACK_BATCH_SIZE)
        consumerSession!!.createTemporaryQueue(clientAddress, RoutingType.ANYCAST, clientAddress)
        rpcConsumer = consumerSession!!.createConsumer(clientAddress)
        rpcConsumer!!.setMessageHandler(this::artemisMessageHandler)
        lifeCycle.transition(State.UNSTARTED, State.SERVER_VERSION_NOT_SET)
        consumerSession!!.start()
        producerSession!!.start()
    }

    /**
     * This is the general function that transforms a client side RPC to internal Artemis messages.
     *
     * `toString()` is answered locally. Every other method is sent to the server, and the calling thread blocks until
     * the reply arrives or the handler is closed. [RuntimeException]s, including those reported by the server, are
     * rethrown as is. Any other [Exception] is wrapped in an [RPCException].
     */
    override fun invoke(proxy: Any, method: Method, arguments: Array<out Any?>?): Any? {
        lifeCycle.requireState { it == State.STARTED || it == State.SERVER_VERSION_NOT_SET }
        checkProtocolVersion(method)
        if (method == toStringMethod) {
            return "Client RPC proxy for $rpcOpsClass"
        }
        if (consumerSession!!.isClosed) {
            throw RPCException("RPC Proxy is closed")
        }

        val replyId = InvocationId.newInstance()
        callSiteMap?.set(replyId, Throwable("<Call site of root RPC '${method.name}'>"))
        try {
            val serialisedArguments = (arguments?.toList() ?: emptyList()).serialize(context = serializationContextWithObservableContext)
            val request = RPCApi.ClientToServer.RpcRequest(
                    clientAddress,
                    method.name,
                    serialisedArguments,
                    replyId,
                    sessionId,
                    externalTrace,
                    impersonatedActor
            )
            val replyFuture = SettableFuture.create<Any>()
            require(rpcReplyMap.put(replyId, replyFuture) == null) {
                "Generated several RPC requests with same ID $replyId"
            }
            try {
                sendMessage(request)
            } catch (e: Exception) {
                // The request was never handed to the sender, so no reply will ever complete this future.
                // Drop it rather than leaking it in the map.
                rpcReplyMap.remove(replyId)
                throw e
            }
            return replyFuture.getOrThrow()
        } catch (e: RuntimeException) {
            // Already an unchecked exception, so just rethrow it
            throw e
        } catch (e: Exception) {
            // This must be a checked exception, so wrap it
            throw RPCException(e.message ?: "", e)
        } finally {
            callSiteMap?.remove(replyId)
        }
    }

    /**
     * Converts [message] into an Artemis message on the caller's thread and hands the actual send to [sendExecutor].
     * The deduplication sequence number is assigned on the single sender thread, so numbers follow the send order.
     */
    private fun sendMessage(message: RPCApi.ClientToServer) {
        val artemisMessage = producerSession!!.createMessage(false)
        message.writeToClientMessage(artemisMessage)
        // NOTE(review): the Future returned by submit() is discarded, so an exception thrown by send() is not reported.
        sendExecutor!!.submit {
            artemisMessage.putLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME, deduplicationSequenceNumber.getAndIncrement())
            log.debug { "-> RPC -> $message" }
            rpcProducer!!.send(artemisMessage)
        }
    }

    /**
     * The handler for Artemis messages from the server. It discards duplicates, completes the matching reply future,
     * or forwards observations to the matching subject on its sticky executor, and then acknowledges the message.
     */
    private fun artemisMessageHandler(message: ClientMessage) {
        val serverToClient = RPCApi.ServerToClient.fromClientMessage(serializationContextWithObservableContext, message)
        val deduplicationSequenceNumber = message.getLongProperty(RPCApi.DEDUPLICATION_SEQUENCE_NUMBER_FIELD_NAME)
        if (deduplicationChecker.checkDuplicateMessageId(serverToClient.deduplicationIdentity, deduplicationSequenceNumber)) {
            log.info("Message duplication detected, discarding message")
            return
        }
        log.debug { "Got message from RPC server $serverToClient" }
        when (serverToClient) {
            is RPCApi.ServerToClient.RpcReply -> {
                val replyFuture = rpcReplyMap.remove(serverToClient.id)
                if (replyFuture == null) {
                    log.error("RPC reply arrived to unknown RPC ID ${serverToClient.id}, this indicates an internal RPC error.")
                } else {
                    val result = serverToClient.result
                    when (result) {
                        is Try.Success -> replyFuture.set(result.value)
                        is Try.Failure -> {
                            val rpcCallSite = callSiteMap?.get(serverToClient.id)
                            if (rpcCallSite != null) addRpcCallSiteToThrowable(result.exception, rpcCallSite)
                            replyFuture.setException(result.exception)
                        }
                    }
                }
            }
            is RPCApi.ServerToClient.Observation -> {
                val observable = observableContext.observableMap.getIfPresent(serverToClient.id)
                if (observable == null) {
                    log.debug {
                        "Observation ${serverToClient.content} arrived to unknown Observable with ID ${serverToClient.id}. " +
                                "This may be due to an observation arriving before the server was " +
                                "notified of observable shutdown"
                    }
                } else {
                    // We schedule the onNext() on an executor sticky-pooled based on the Observable ID.
                    observationExecutorPool.run(serverToClient.id) { executor ->
                        executor.submit {
                            val content = serverToClient.content
                            if (content.isOnCompleted || content.isOnError) {
                                observableContext.observableMap.invalidate(serverToClient.id)
                            }
                            // Add call site information on error
                            if (content.isOnError) {
                                val rpcCallSite = callSiteMap?.get(serverToClient.id)
                                if (rpcCallSite != null) addRpcCallSiteToThrowable(content.throwable, rpcCallSite)
                            }
                            observable.onNext(content)
                        }
                    }
                }
            }
        }
        message.acknowledge()
    }

    /**
     * Closes this handler without notifying observables.
     * This method clears up only local resources and as such does not block on any network resources.
     */
    fun forceClose() {
        close(false)
    }

    /**
     * Closes this handler and sends notifications to all observables, so it can immediately clean up resources.
     * Notifications sent to observables are to be acknowledged, therefore this call blocks until all acknowledgements are received.
     * If this is not convenient see the [forceClose] method.
     * If an observable is not accessible this method may block for a duration of the message broker timeout.
     */
    fun notifyServerAndClose() {
        close(true)
    }

    /**
     * Closes the RPC proxy. Reaps all observables, fails all pending RPC replies, shuts down the reaper, and closes all
     * sessions and executors.
     * When observables are to be notified (i.e. the [notify] parameter is true),
     * the method blocks until all the messages are acknowledged by the observables.
     * Note: If any of the observables is inaccessible, the method blocks for the duration of the timeout set on the message broker.
     *
     * @param notify whether to notify observables or not.
     */
    private fun close(notify: Boolean = true) {
        sessionFactory?.close()
        reaperScheduledFuture?.cancel(false)
        observableContext.observableMap.invalidateAll()
        // The sessions that deliver replies have just been closed, so release any caller still blocked in invoke().
        failPendingReplies()
        // NOTE(review): with notify == true, reapObservables() only queues the ObservablesClosed message on
        // sendExecutor. The session factory is already closed and sendExecutor is shut down with shutdownNow() just
        // below, so that message may never reach the server. Confirm notifyServerAndClose() behaves as documented.
        reapObservables(notify)
        reaperExecutor?.shutdownNow()
        sendExecutor?.shutdownNow()
        // Note the ordering is important, we shut down the consumer *before* the observation executor, otherwise we may
        // leak borrowed executors.
        val observationExecutors = observationExecutorPool.close()
        observationExecutors.forEach { it.shutdownNow() }
        lifeCycle.justTransition(State.FINISHED)
    }

    /**
     * Completes every reply future still in [rpcReplyMap] exceptionally, so that threads blocked in [invoke] return.
     */
    private fun failPendingReplies() {
        for (replyId in rpcReplyMap.keys) {
            // remove() is atomic, so a given future is completed either here or by the message handler, never both.
            rpcReplyMap.remove(replyId)?.setException(RPCException("RPC Proxy is closed"))
        }
    }

    /**
     * Check the [RPCSinceVersion] of the passed in [calledMethod] against the server's protocol version.
     * Methods without the annotation are treated as version 0.
     *
     * @throws UnsupportedOperationException if the method requires a newer protocol than the server runs.
     */
    private fun checkProtocolVersion(calledMethod: Method) {
        val serverProtocolVersion = serverProtocolVersion
        if (serverProtocolVersion == null) {
            lifeCycle.requireState(State.SERVER_VERSION_NOT_SET)
        } else {
            lifeCycle.requireState(State.STARTED)
            val sinceVersion = calledMethod.getAnnotation(RPCSinceVersion::class.java)?.version ?: 0
            if (sinceVersion > serverProtocolVersion) {
                throw UnsupportedOperationException("Method $calledMethod was added in RPC protocol version $sinceVersion but the server is running $serverProtocolVersion")
            }
        }
    }

    /**
     * Set the server's protocol version. Note that before doing so the client is not considered fully started, although
     * RPCs already may be called with it.
     *
     * @throws IllegalStateException if the version was already set.
     */
    internal fun setServerProtocolVersion(version: Int) {
        check(serverProtocolVersion == null) { "setServerProtocolVersion called, but the protocol version was already set!" }
        serverProtocolVersion = version
        lifeCycle.transition(State.SERVER_VERSION_NOT_SET, State.STARTED)
    }

    // Entry point for the scheduled reaper: always notifies the server.
    private fun reapObservablesAndNotify() = reapObservables()

    /**
     * Runs pending maintenance on the Observable cache, which feeds removed IDs into [observablesToReap] through the
     * removal listener. When [notify] is true, all queued IDs are then sent to the server in one ObservablesClosed
     * message. When it is false, the IDs remain queued.
     */
    private fun reapObservables(notify: Boolean = true) {
        observableContext.observableMap.cleanUp()
        if (!notify) return
        val observableIds = observablesToReap.locked {
            if (observables.isNotEmpty()) {
                val temporary = observables
                observables = ArrayList()
                temporary
            } else {
                null
            }
        }
        if (observableIds != null) {
            log.debug { "Reaping ${observableIds.size} observables" }
            sendMessage(RPCApi.ClientToServer.ObservablesClosed(observableIds))
        }
    }
}

// Call-site stack traces keyed by RPC request ID or Observable ID. Note that ConcurrentHashMap rejects null values,
// so despite the nullable value type only non-null Throwables can actually be stored.
private typealias CallSiteMap = ConcurrentHashMap<InvocationId, Throwable?>
// Client-side subjects keyed by Observable ID; server observations are demultiplexed through this cache.
private typealias RpcObservableMap = Cache<InvocationId, UnicastSubject<Notification<*>>>
// Pending RPC reply futures keyed by request ID.
private typealias RpcReplyMap = ConcurrentHashMap<InvocationId, SettableFuture<Any?>>

/**
 * Holds a context available during Kryo deserialisation of messages that are expected to contain Observables.
 *
 * @param callSiteMap optionally holds the call-site stack traces of RPCs and Observables, used to enrich errors.
 * RPCClientProxyHandler passes null when call-site tracking is disabled.
 * @param observableMap holds the Observables that are ultimately exposed to the user, keyed by Observable ID.
 * @param hardReferenceStore holds references to Observables we want to keep alive while they are subscribed to.
 */
data class ObservableContext(
        val callSiteMap: CallSiteMap?,
        val observableMap: RpcObservableMap,
        val hardReferenceStore: MutableSet<Observable<*>>
)

/**
 * A [Serializer] to deserialise Observables once the corresponding Kryo instance has been provided with an [ObservableContext].
 *
 * Each deserialised Observable is backed by a fresh [UnicastSubject] registered in [ObservableContext.observableMap]
 * under the Observable's ID. Serialising Observables is not supported on the client side.
 */
object RpcClientObservableSerializer : Serializer<Observable<*>>() {
    // Key under which the ObservableContext is stored in the serialisation context properties.
    private object RpcObservableContextKey

    /**
     * Returns [serializationContext] extended with [observableContext] as a property, for use by [read].
     */
    fun createContext(serializationContext: SerializationContext, observableContext: ObservableContext): SerializationContext {
        return serializationContext.withProperty(RpcObservableContextKey, observableContext)
    }

    /**
     * Returns a view of [observable] that keeps it in [hardReferenceStore] while it has at least one subscriber. The
     * reference is added on the first subscription and removed when the subscriber count drops back to zero.
     */
    private fun <T> pinInSubscriptions(observable: Observable<T>, hardReferenceStore: MutableSet<Observable<*>>): Observable<T> {
        val refCount = AtomicInteger(0)
        // Inside these lambdas `this` is the serializer object, so the messages name the observable explicitly.
        return observable.doOnSubscribe {
            if (refCount.getAndIncrement() == 0) {
                require(hardReferenceStore.add(observable)) { "Reference store already contained reference $observable on add" }
            }
        }.doOnUnsubscribe {
            if (refCount.decrementAndGet() == 0) {
                require(hardReferenceStore.remove(observable)) { "Reference store did not contain reference $observable on remove" }
            }
        }
    }

    /**
     * Reads an Observable ID from [input] and registers a new [UnicastSubject] for it in the [ObservableContext] found
     * in the Kryo context. It then returns an Observable that dematerialises the notifications forwarded to that subject.
     *
     * @throws IllegalStateException if no ID can be read from [input].
     * @throws IllegalArgumentException if an Observable with the same ID is already registered.
     */
    override fun read(kryo: Kryo, input: Input, type: Class<Observable<*>>): Observable<Any> {
        val observableContext = kryo.context[RpcObservableContextKey] as ObservableContext
        val observableId = input.readInvocationId() ?: throw IllegalStateException("Unable to read invocationId from Input.")
        val observable = UnicastSubject.create<Notification<*>>()
        require(observableContext.observableMap.getIfPresent(observableId) == null) {
            "Multiple Observables arrived with the same ID $observableId"
        }
        val rpcCallSite = getRpcCallSite(kryo, observableContext)
        observableContext.observableMap.put(observableId, observable)
        // The call-site map is a ConcurrentHashMap, which rejects null values. The parent RPC/Observable may no longer
        // have an entry (e.g. it was already removed), so only record a call site that was actually found.
        if (rpcCallSite != null) {
            observableContext.callSiteMap?.put(observableId, rpcCallSite)
        }
        // We pin all Observables into a hard reference store (rooted in the RPC proxy) on subscription so that users
        // don't need to store a reference to the Observables themselves.
        return pinInSubscriptions(observable, observableContext.hardReferenceStore).doOnUnsubscribe {
            // This causes Future completions to give warnings because the corresponding OnComplete sent from the server
            // will arrive after the client unsubscribes from the observable and consequently invalidates the mapping.
            // The unsubscribe is due to [ObservableToFuture]'s use of first().
            observableContext.observableMap.invalidate(observableId)
        }.dematerialize()
    }

    /**
     * Reads an [InvocationId] stored as its string value followed by a timestamp in epoch milliseconds.
     * Returns null if the string value is null.
     */
    private fun Input.readInvocationId(): InvocationId? {
        val value = readString() ?: return null
        val timestamp = readLong()
        return InvocationId(value, Instant.ofEpochMilli(timestamp))
    }

    override fun write(kryo: Kryo, output: Output, observable: Observable<*>) {
        throw UnsupportedOperationException("Cannot serialise Observables on the client side")
    }

    /**
     * Looks up the call site recorded for the ID stored in the Kryo context under [RPCApi.RpcRequestOrObservableIdKey]
     * (presumably the RPC request or parent Observable whose message is being deserialised). Returns null when
     * call-site tracking is disabled or no entry exists.
     */
    private fun getRpcCallSite(kryo: Kryo, observableContext: ObservableContext): Throwable? {
        val rpcRequestOrObservableId = kryo.context[RPCApi.RpcRequestOrObservableIdKey] as InvocationId
        return observableContext.callSiteMap?.get(rpcRequestOrObservableId)
    }
}