import type { CancelReason } from '@vitest/runner'
import type { BirpcOptions, BirpcReturn } from 'birpc'
import type { RunnerRPC, RuntimeRPC } from '../types/rpc'
import type { WorkerRPC } from '../types/worker'
import { getSafeTimers } from '@vitest/utils/timers'
import { createBirpc } from 'birpc'
import { getWorkerState } from './utils'

// Local alias for `Reflect.get`; the proxy trap in `createSafeRpc` uses it to
// forward property reads to the wrapped RPC object.
const get = Reflect.get

/**
 * Runs `fn` with the timer functions returned by `getSafeTimers()` installed on
 * `globalThis` (and on `process.nextTick` when a `process` object exists).
 * It then restores whatever was installed before.
 * NOTE(review): the "safe" timers are presumably the unmocked originals captured
 * by `@vitest/utils/timers`. Confirm there.
 *
 * `setTimeout`/`clearTimeout`/`setImmediate`/`clearImmediate` are restored
 * synchronously in `finally`. `process.nextTick` is restored one tick later,
 * scheduled through the safe `nextTick`.
 *
 * @param fn Callback to run. Its return value (for example a promise) is passed
 *   through unchanged.
 * @returns Whatever `fn` returns.
 */
function withSafeTimers<T>(fn: () => T): T {
  const { setTimeout, clearTimeout, nextTick, setImmediate, clearImmediate }
    = getSafeTimers()

  const currentSetTimeout = globalThis.setTimeout
  const currentClearTimeout = globalThis.clearTimeout
  const currentSetImmediate = globalThis.setImmediate
  const currentClearImmediate = globalThis.clearImmediate

  const currentNextTick = globalThis.process?.nextTick

  try {
    globalThis.setTimeout = setTimeout
    globalThis.clearTimeout = clearTimeout

    if (setImmediate) {
      globalThis.setImmediate = setImmediate
    }
    if (clearImmediate) {
      globalThis.clearImmediate = clearImmediate
    }

    // Only take over `process.nextTick` when it is not already the safe one.
    // Its restore is deferred by a tick. Without this check, a second call in the
    // same tick would capture the safe function as "current", and its late restore
    // would clobber the one scheduled by the first call (e.g. dropping a faked nextTick).
    if (globalThis.process && nextTick && currentNextTick !== nextTick) {
      globalThis.process.nextTick = nextTick
    }

    return fn()
  }
  finally {
    globalThis.setTimeout = currentSetTimeout
    globalThis.clearTimeout = currentClearTimeout

    // Restore only what was replaced above. This keeps environments that lack these
    // globals from gaining an own `undefined` property.
    if (setImmediate) {
      globalThis.setImmediate = currentSetImmediate
    }
    if (clearImmediate) {
      globalThis.clearImmediate = currentClearImmediate
    }

    if (globalThis.process && nextTick && currentNextTick !== nextTick) {
      nextTick(() => {
        globalThis.process.nextTick = currentNextTick
      })
    }
  }
}

/** RPC calls issued through `createSafeRpc` that have not settled yet. */
const promises = new Set<Promise<unknown>>()

/**
 * Waits for every RPC call that is currently in flight.
 *
 * @returns `undefined` when nothing is pending. Otherwise, the resolved values of
 *   the pending calls, in insertion order. Rejects if any pending call rejects.
 */
export async function rpcDone(): Promise<unknown[] | undefined> {
  const pending = [...promises]
  return pending.length === 0 ? undefined : Promise.all(pending)
}

/** Listeners invoked by the `onCancel` RPC handler set up in `createRuntimeRpc`. */
const onCancelCallbacks: Array<(reason: CancelReason) => void> = []

/**
 * Registers `callback` to be invoked with the cancel reason when an `onCancel`
 * RPC call is received. Nothing in this module removes registered callbacks.
 *
 * @param callback Listener receiving the cancel reason.
 */
export function onCancel(callback: (reason: CancelReason) => void): void {
  onCancelCallbacks[onCancelCallbacks.length] = callback
}

/**
 * Creates the worker-side birpc client and wraps it with `createSafeRpc`.
 *
 * The local `onCancel` handler fans the incoming reason out to every callback
 * registered through `onCancel()`. `onUserConsoleLog`, `onCollected` and
 * `onCancel` are passed to birpc as `eventNames` (calls that expect no response).
 * `timeout: -1` presumably disables birpc's per-call timeout (verify against
 * the birpc version in use).
 *
 * @param options Transport hooks (`on`/`post`) and optional
 *   `serialize`/`deserialize` functions forwarded to birpc.
 * @returns The wrapped RPC client.
 */
export function createRuntimeRpc(
  options: Pick<
    BirpcOptions<RuntimeRPC>,
    'on' | 'post' | 'serialize' | 'deserialize'
  >,
): WorkerRPC {
  return createSafeRpc(
    createBirpc<RuntimeRPC, RunnerRPC>(
      {
        async onCancel(reason) {
          // Each callback runs inside an async wrapper. A synchronous throw then
          // becomes a rejection instead of aborting `map` and skipping the
          // remaining callbacks. Callbacks are still invoked in registration order.
          await Promise.all(onCancelCallbacks.map(async fn => fn(reason)))
        },
      },
      {
        eventNames: [
          'onUserConsoleLog',
          'onCollected',
          'onCancel',
        ],
        timeout: -1,
        ...options,
      },
    ),
  )
}

/**
 * Wraps an RPC client so that every method call runs inside `withSafeTimers`.
 * The promise of each call is tracked in `promises` until it settles, which lets
 * `rpcDone()` wait for in-flight calls.
 *
 * `$rejectPendingCalls` is returned as-is so it stays synchronous. Any other
 * non-function property (e.g. a missing `then`, or plain values exposed by the
 * client) is also returned unchanged instead of being wrapped.
 *
 * @param rpc The client to wrap.
 * @returns A proxy over `rpc` with wrapped methods.
 */
export function createSafeRpc(rpc: WorkerRPC): WorkerRPC {
  return new Proxy(rpc, {
    get(target, p, receiver) {
      // keep $rejectPendingCalls as sync function
      if (p === '$rejectPendingCalls') {
        return rpc.$rejectPendingCalls
      }

      const sendCall = get(target, p, receiver)
      // Only callables can be wrapped. Otherwise reading `sendCall.asEvent` below
      // would throw for missing properties (e.g. a thenable probe of `.then`),
      // and plain values would be replaced by functions.
      if (typeof sendCall !== 'function') {
        return sendCall
      }

      const safeSendCall = (...args: any[]) =>
        withSafeTimers(async () => {
          const result = sendCall(...args)
          promises.add(result)
          try {
            return await result
          }
          finally {
            promises.delete(result)
          }
        })
      // birpc exposes a fire-and-forget variant as `.asEvent`; forward it unwrapped.
      safeSendCall.asEvent = sendCall.asEvent
      return safeSendCall
    },
  })
}

/**
 * Returns the RPC client stored on the current worker state.
 */
export function rpc(): BirpcReturn<RuntimeRPC, RunnerRPC> {
  return getWorkerState().rpc
}
