package net.corda.flow.pipeline.impl

import net.corda.data.flow.output.FlowStatus
import net.corda.data.flow.state.waiting.WaitingFor
import net.corda.flow.fiber.cache.FlowFiberCache
import net.corda.flow.maintenance.CheckpointCleanupHandler
import net.corda.flow.pipeline.FlowEventExceptionProcessor
import net.corda.flow.pipeline.events.FlowEventContext
import net.corda.flow.pipeline.exceptions.FlowEventException
import net.corda.flow.pipeline.exceptions.FlowFatalException
import net.corda.flow.pipeline.exceptions.FlowMarkedForKillException
import net.corda.flow.pipeline.exceptions.FlowPlatformException
import net.corda.flow.pipeline.exceptions.FlowProcessingExceptionTypes.PLATFORM_ERROR
import net.corda.flow.pipeline.factory.FlowMessageFactory
import net.corda.flow.pipeline.factory.FlowRecordFactory
import net.corda.flow.pipeline.addTerminationKeyToMeta
import net.corda.flow.pipeline.sessions.FlowSessionManager
import net.corda.flow.state.FlowCheckpoint
import net.corda.libs.configuration.SmartConfig
import net.corda.messaging.api.records.Record
import org.osgi.service.component.annotations.Activate
import org.osgi.service.component.annotations.Component
import org.osgi.service.component.annotations.Reference
import org.slf4j.LoggerFactory

@Suppress("Unused", "TooManyFunctions")
@Component(service = [FlowEventExceptionProcessor::class])
/**
 * OSGi-registered [FlowEventExceptionProcessor] that converts a failure raised while processing a flow event into an
 * updated [FlowEventContext].
 *
 * Handling per exception type:
 * - any other [Throwable]: the checkpoint is marked deleted, output records are dropped, the metadata gets a
 *   termination key and the event is flagged for the DLQ;
 * - [FlowFatalException]: the cached fiber is evicted, the checkpoint is cleaned up via [CheckpointCleanupHandler],
 *   the metadata gets a termination key and the event is flagged for the DLQ;
 * - [FlowEventException]: logged and the cached fiber is evicted; the context is otherwise returned unchanged;
 * - [FlowPlatformException]: a pending platform error is recorded on the checkpoint, its `waitingFor` is set to a
 *   `Wakeup`, and the cached fiber is evicted;
 * - [FlowMarkedForKillException]: like the fatal case, but the event is NOT sent to the DLQ.
 *
 * Every typed handler runs inside [withEscalation], so a failure inside a handler falls back to the catch-all
 * [Throwable] handler instead of propagating out of the pipeline.
 */
class FlowEventExceptionProcessorImpl @Activate constructor(
    @Reference(service = FlowMessageFactory::class)
    private val flowMessageFactory: FlowMessageFactory,
    @Reference(service = FlowRecordFactory::class)
    private val flowRecordFactory: FlowRecordFactory,
    @Reference(service = FlowSessionManager::class)
    private val flowSessionManager: FlowSessionManager,
    @Reference(service = FlowFiberCache::class)
    private val flowFiberCache: FlowFiberCache,
    @Reference(service = CheckpointCleanupHandler::class)
    private val checkpointCleanupHandler: CheckpointCleanupHandler
) : FlowEventExceptionProcessor {

    private companion object {
        // `this` is the companion object here, so `enclosingClass` resolves to the outer processor class.
        private val log = LoggerFactory.getLogger(this::class.java.enclosingClass)
    }

    /** This processor currently reads nothing from the supplied configuration. */
    override fun configure(config: SmartConfig) {
    }

    /**
     * Catch-all handler for unexpected exceptions, and the escalation target when another handler fails.
     *
     * Marks the checkpoint as deleted, discards any output records, adds a termination key to the metadata and flags
     * the event for the DLQ. This handler is not itself protected by [withEscalation]; anything it throws propagates.
     */
    override fun process(throwable: Throwable, context: FlowEventContext<*>): FlowEventContext<*> {
        log.warn("Unexpected exception while processing flow, the flow will be sent to the DLQ", throwable)
        context.checkpoint.markDeleted()
        val metaWithTermination = addTerminationKeyToMeta(context.metadata)
        return context.copy(
            outputRecords = listOf(),
            sendToDlq = true,
            metadata = metaWithTermination
        )
    }

    /**
     * Handles an unrecoverable flow failure: logs it, evicts the cached fiber, cleans up the checkpoint and flags
     * the event for the DLQ. The records produced by the cleanup handler replace the context's output records.
     */
    override fun process(
        exception: FlowFatalException,
        context: FlowEventContext<*>
    ): FlowEventContext<*> = withEscalation(exception, context) {
        val checkpoint = context.checkpoint

        // Only read the start context when a checkpoint exists; otherwise report its absence instead.
        val msg = if (!checkpoint.doesExist) {
            "Flow processing for flow ID ${checkpoint.flowId} has failed due to a fatal exception. " +
                    "Checkpoint/Flow start context doesn't exist"
        } else {
            "Flow processing for flow ID ${checkpoint.flowId} has failed due to a fatal exception. " +
                    "Flow start context: ${checkpoint.flowStartContext}"
        }
        log.warn(msg, exception)

        removeCachedFlowFiber(checkpoint)
        val cleanupRecords = checkpointCleanupHandler.cleanupCheckpoint(checkpoint, context.flowConfig, exception)

        val metaWithTermination = addTerminationKeyToMeta(context.metadata)
        context.copy(
            outputRecords = cleanupRecords,
            sendToDlq = true,
            metadata = metaWithTermination
        )
    }

    /**
     * Builds a single flow status record using [statusGenerator].
     *
     * If the generator throws [IllegalStateException], a warning is logged and an empty list is returned.
     *
     * NOTE(review): not called anywhere in this class other than from [createFlowKilledStatusRecord], which is
     * itself unused.
     */
    private fun createStatusRecord(id: String, statusGenerator: () -> FlowStatus): List<Record<*, *>> {
        return try {
            val status = statusGenerator()
            listOf(flowRecordFactory.createFlowStatusRecord(status))
        } catch (e: IllegalStateException) {
            // Most errors should happen after a flow has been initialised. However, it is possible for
            // initialisation to have not yet happened at the point the failure is hit if it's a session init message
            // and something goes wrong in trying to retrieve the sandbox. In this case we cannot update the status
            // correctly. This shouldn't matter however - in this case we're treating the issue as the flow never
            // starting at all. We'll still log that the error was seen.
            log.warn(
                "Could not create a flow status message for a flow with ID $id as the flow start context was missing."
            )
            listOf()
        }
    }

    /**
     * Handles a non-critical event error: logs its message (without a stack trace) and evicts the cached fiber.
     * The context is returned unchanged.
     */
    override fun process(
        exception: FlowEventException,
        context: FlowEventContext<*>
    ): FlowEventContext<*> = withEscalation(exception, context) {
        log.warn("A non critical error was reported while processing the event: ${exception.message}")

        removeCachedFlowFiber(context.checkpoint)

        context
    }

    /**
     * Handles a platform error: records it on the checkpoint as a pending [PLATFORM_ERROR], sets the checkpoint's
     * `waitingFor` to a `Wakeup`, and evicts the cached fiber. The context itself is returned unchanged.
     */
    override fun process(
        exception: FlowPlatformException,
        context: FlowEventContext<*>
    ): FlowEventContext<*> = withEscalation(exception, context) {
        val checkpoint = context.checkpoint

        checkpoint.setPendingPlatformError(PLATFORM_ERROR, exception.message)
        checkpoint.waitingFor = WaitingFor(net.corda.data.flow.state.waiting.Wakeup())

        removeCachedFlowFiber(checkpoint)

        context
    }

    /**
     * Handles a flow that has been marked for kill: evicts the cached fiber, cleans up the checkpoint and adds a
     * termination key to the metadata. Unlike the fatal case, the event is not sent to the DLQ.
     */
    override fun process(
        exception: FlowMarkedForKillException,
        context: FlowEventContext<*>
    ): FlowEventContext<*> = withEscalation(exception, context) {
        val checkpoint = context.checkpoint

        removeCachedFlowFiber(checkpoint)
        val cleanupRecords = checkpointCleanupHandler.cleanupCheckpoint(checkpoint, context.flowConfig, exception)
        val metaWithTermination = addTerminationKeyToMeta(context.metadata)
        context.copy(
            outputRecords = cleanupRecords,
            sendToDlq = false, // killed flows do not go to DLQ
            metadata = metaWithTermination
        )
    }

    /**
     * Runs [handler]. If it throws, escalates to the catch-all [process] overload, which forcibly sends the
     * offending event to the DLQ instead of taking the whole pipeline down.
     *
     * @param exception the exception [handler] was trying to deal with. It is logged before escalating, because the
     *   catch-all path only logs the handler's own failure, and some handlers never log their input at all.
     * @param context the context passed to the catch-all handler on escalation.
     */
    private fun withEscalation(
        exception: Throwable,
        context: FlowEventContext<*>,
        handler: () -> FlowEventContext<*>
    ): FlowEventContext<*> {
        return try {
            handler()
        } catch (t: Throwable) {
            // The exception handler failed. Record what it was handling, then forcibly DLQ the offending event.
            log.warn(
                "Failed to handle ${exception::class.java.name}; escalating to the DLQ. Original exception:",
                exception
            )
            process(t, context)
        }
    }

    /**
     * Builds a "flow killed" status record for [checkpoint] with the given [message].
     *
     * NOTE(review): not called anywhere in this class.
     */
    private fun createFlowKilledStatusRecord(checkpoint: FlowCheckpoint, message: String?): List<Record<*, *>> {
        return createStatusRecord(checkpoint.flowId) {
            flowMessageFactory.createFlowKilledStatusMessage(checkpoint, message)
        }
    }

    /**
     * Removes the cached flow fiber for this checkpoint. Does nothing if the checkpoint does not exist.
     */
    private fun removeCachedFlowFiber(checkpoint: FlowCheckpoint) {
        if (checkpoint.doesExist) flowFiberCache.remove(checkpoint.flowKey)
    }
}
