package net.corda.messaging.mediator.factory


import net.corda.messaging.api.mediator.MediatorConsumer
import net.corda.messaging.api.mediator.MediatorInputService
import net.corda.messaging.api.mediator.MessageRouter
import net.corda.messaging.api.mediator.MessagingClient
import net.corda.messaging.api.mediator.config.EventMediatorConfig
import net.corda.messaging.api.mediator.config.MediatorConsumerConfig
import net.corda.messaging.api.mediator.config.MessagingClientConfig
import net.corda.messaging.api.mediator.factory.MediatorConsumerFactory
import net.corda.messaging.api.mediator.factory.MessageRouterFactory
import net.corda.messaging.api.mediator.factory.MessagingClientFactory
import net.corda.messaging.api.mediator.factory.MessagingClientFinder
import net.corda.messaging.api.processor.StateAndEventProcessor
import net.corda.messaging.api.processor.StateAndEventProcessor.State
import net.corda.messaging.api.records.Record
import net.corda.messaging.mediator.GroupAllocator
import net.corda.messaging.mediator.MediatorSubscriptionState
import net.corda.messaging.mediator.StateManagerHelper
import net.corda.taskmanager.TaskManager
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import org.mockito.Mockito
import org.mockito.kotlin.any
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever
import java.util.concurrent.atomic.AtomicBoolean

/**
 * Unit tests for [MediatorComponentFactory].
 *
 * Every collaborator handed to the factory is a Mockito mock. The tests check how the factory uses
 * the consumer/client/router factories it is given, and that it rejects empty factory lists.
 */
class MediatorComponentFactoryTest {
    private lateinit var mediatorComponentFactory: MediatorComponentFactory<String, String, String>

    // Minimal processor exposing String key/state/event classes.
    // onNext is a TODO stub, so any test that reaches it would fail.
    private val messageProcessor = object : StateAndEventProcessor<String, String, String> {
        override fun onNext(
            state: State<String>?, event: Record<String, String>
        ): StateAndEventProcessor.Response<String> {
            TODO("Not yet implemented")
        }

        override val keyClass get() = String::class.java
        override val stateValueClass get() = String::class.java
        override val eventValueClass get() = String::class.java
    }
    private val consumerFactories = listOf(
        mock<MediatorConsumerFactory>(),
        mock<MediatorConsumerFactory>(),
    )
    private val clientFactories = listOf(
        mock<MessagingClientFactory>(),
        mock<MessagingClientFactory>(),
    )
    private val messageRouterFactory = mock<MessageRouterFactory>()
    private val groupAllocator = mock<GroupAllocator>()
    private val stateManagerHelper = mock<StateManagerHelper<String>>()
    private val taskManager = mock<TaskManager>()
    private val messageRouter = mock<MessageRouter>()
    private val mediatorInputService = mock<MediatorInputService>()
    private val mediatorSubscriptionState = MediatorSubscriptionState(AtomicBoolean(false), AtomicBoolean(false))

    // Config mock: only `name` and `stateManager` are stubbed.
    private val eventMediatorConfig = mock<EventMediatorConfig<String, String, String>>().apply {
        whenever(name).thenReturn("name")
        whenever(stateManager).thenReturn(mock())
    }

    /**
     * Stubs every factory to return a fresh mock from `create(...)`.
     * Then builds the factory under test from the shared fixtures.
     */
    @BeforeEach
    fun beforeEach() {
        consumerFactories.forEach { factory ->
            doReturn(mock<MediatorConsumer<String, String>>()).whenever(factory).create(
                any<MediatorConsumerConfig<String, String>>()
            )
        }

        clientFactories.forEach { factory ->
            doReturn(mock<MessagingClient>()).whenever(factory).create(
                any<MessagingClientConfig>()
            )
        }

        doReturn(mock<MessageRouter>()).whenever(messageRouterFactory).create(
            any<MessagingClientFinder>()
        )

        mediatorComponentFactory = MediatorComponentFactory(
            messageProcessor,
            consumerFactories,
            clientFactories,
            messageRouterFactory,
            groupAllocator,
            stateManagerHelper,
            mediatorInputService
        )
    }

    @Test
    fun `successfully creates consumers`() {
        val onSerializationError: (ByteArray) -> Unit = {}

        val mediatorConsumers = mediatorComponentFactory.createConsumers(onSerializationError)

        assertEquals(consumerFactories.size, mediatorConsumers.size)
        mediatorConsumers.forEach {
            assertNotNull(it)
        }

        // Each factory must be called exactly once.
        // The config passed in must carry the processor's String types and the caller's error callback.
        consumerFactories.forEach { factory ->
            val consumerConfigCaptor = argumentCaptor<MediatorConsumerConfig<String, String>>()
            verify(factory).create(consumerConfigCaptor.capture())
            val consumerConfig = consumerConfigCaptor.firstValue
            assertEquals(String::class.java, consumerConfig.keyClass)
            assertEquals(String::class.java, consumerConfig.valueClass)
            assertEquals(onSerializationError, consumerConfig.onSerializationError)
        }
    }

    @Test
    fun `throws exception when consumer factory not provided`() {
        val mediatorComponentFactory = MediatorComponentFactory(
            messageProcessor,
            emptyList(),
            clientFactories,
            messageRouterFactory,
            groupAllocator,
            stateManagerHelper,
            mediatorInputService
        )

        assertThrows<IllegalStateException> {
            mediatorComponentFactory.createConsumers { }
        }
    }

    @Test
    fun `successfully creates clients`() {
        val onSerializationError: (ByteArray) -> Unit = {}

        val mediatorClients = mediatorComponentFactory.createClients(onSerializationError)

        assertEquals(clientFactories.size, mediatorClients.size)
        mediatorClients.forEach {
            assertNotNull(it)
        }

        // Each factory must be called exactly once with a config carrying the caller's error callback.
        clientFactories.forEach { factory ->
            val clientConfigCaptor = argumentCaptor<MessagingClientConfig>()
            verify(factory).create(clientConfigCaptor.capture())
            val clientConfig = clientConfigCaptor.firstValue
            assertEquals(onSerializationError, clientConfig.onSerializationError)
        }
    }

    @Test
    fun `throws exception when client factory not provided`() {
        val mediatorComponentFactory = MediatorComponentFactory(
            messageProcessor,
            consumerFactories,
            emptyList(),
            messageRouterFactory,
            groupAllocator,
            stateManagerHelper,
            mediatorInputService
        )

        assertThrows<IllegalStateException> {
            mediatorComponentFactory.createClients { }
        }
    }

    @Test
    fun `successfully creates message router`() {
        val clients = listOf(
            mock<MessagingClient>(),
            mock<MessagingClient>(),
        )
        // Give each client a distinct id ("0", "1", ...) so the finder can be queried by id.
        clients.forEachIndexed { index, client ->
            doReturn(index.toString()).whenever(client).id
        }

        // Named `router` so it does not shadow the class-level `messageRouter` fixture.
        val router = mediatorComponentFactory.createRouter(clients)

        assertNotNull(router)

        // Capture the finder the factory handed to the router factory and check its id -> client lookup.
        val messagingClientFinderCaptor = argumentCaptor<MessagingClientFinder>()
        verify(messageRouterFactory).create(messagingClientFinderCaptor.capture())
        val messagingClientFinder = messagingClientFinderCaptor.firstValue

        clients.forEachIndexed { index, client ->
            assertEquals(client, messagingClientFinder.find(index.toString()))
        }
        assertThrows<IllegalStateException> {
            messagingClientFinder.find("unknownId")
        }
    }

    @Test
    fun `create a consumer processor`() {
        val consumerProcessor = mediatorComponentFactory.createConsumerProcessor(
            eventMediatorConfig,
            taskManager,
            messageRouter,
            mediatorSubscriptionState,
        )

        assertThat(consumerProcessor).isNotNull()
    }
}