{-# LANGUAGE RecordWildCards #-}

-- This file is part of the Wire Server implementation.
--
-- Copyright (C) 2025 Wire Swiss GmbH <opensource@wire.com>
--
-- This program is free software: you can redistribute it and/or modify it under
-- the terms of the GNU Affero General Public License as published by the Free
-- Software Foundation, either version 3 of the License, or (at your option) any
-- later version.
--
-- This program is distributed in the hope that it will be useful, but WITHOUT
-- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
-- FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
-- details.
--
-- You should have received a copy of the GNU Affero General Public License along
-- with this program. If not, see <https://www.gnu.org/licenses/>.

module Network.AMQP.Extended
  ( RabbitMqHooks (..),
    RabbitMqAdminOpts (..),
    AmqpEndpoint (..),
    openConnectionWithRetries,
    mkRabbitMqAdminClientEnv,
    mkRabbitMqAdminClientEnvWithCreds,
    mkRabbitMqChannelMVar,
    demoteOpts,
    RabbitMqTlsOpts (..),
    mkConnectionOpts,
    mkTLSSettings,
    readCredsFromEnv,
  )
where

import Control.Exception (AsyncException, throwIO)
import Control.Monad.Catch
import Control.Monad.Trans.Control
import Control.Monad.Trans.Maybe
import Control.Retry
import Data.Aeson
import Data.Aeson.Types
import Data.Default
import Data.Proxy
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import Data.X509.CertificateStore qualified as X509
import Imports
import Network.AMQP qualified as Q
import Network.Connection as Conn
import Network.HTTP.Client qualified as HTTP
import Network.HTTP.Client.TLS qualified as HTTP
import Network.RabbitMqAdmin
import Network.TLS
import Network.TLS.Extra.Cipher
import Servant hiding (Handler)
import Servant.Client
import Servant.Client qualified as Servant
import System.Logger (Logger)
import System.Logger qualified as Log
import UnliftIO.Async

-- | Callbacks invoked by 'openConnectionWithRetries' during the lifetime of a
-- RabbitMQ connection.
data RabbitMqHooks m = RabbitMqHooks
  { -- | Called whenever there is a new channel. At any time there should be at
    -- most one open channel. This may need to change in the future.
    onNewChannel :: Q.Channel -> m (),
    -- | Called when the connection is closed, before a reconnect is attempted.
    -- Any exception thrown by this hook is logged and ignored.
    onConnectionClose :: m (),
    -- | Called with the exception passed to the channel's exception handler,
    -- before it is decided whether the channel gets reopened. Any exception
    -- thrown by this hook is logged and ignored.
    onChannelException :: SomeException -> m ()
  }

-- | TLS options used when connecting to RabbitMQ. See 'mkTLSSettings' for how
-- they are turned into TLS client parameters.
data RabbitMqTlsOpts = RabbitMqTlsOpts
  { -- | Optional path from which a certificate store is read and used as the
    -- CA store.
    caCert :: !(Maybe FilePath),
    -- | When 'True', the server certificate hook reports no validation
    -- errors, i.e. the server certificate is accepted without verification.
    insecureSkipVerifyTls :: Bool
  }
  deriving (Eq, Show)

-- | Parse the optional TLS part of a RabbitMQ configuration object.
--
-- TLS is only configured when the @enableTls@ key is present and true. In that
-- case @caCert@ is optional and @insecureSkipVerifyTls@ defaults to 'False'.
-- Otherwise the result is 'Nothing' and the other TLS keys are not looked at.
parseTlsJson :: Object -> Parser (Maybe RabbitMqTlsOpts)
parseTlsJson obj =
  obj .:? "enableTls" .!= False >>= \case
    False -> pure Nothing
    True -> do
      caCertPath <- obj .:? "caCert"
      skipVerify <- obj .:? "insecureSkipVerifyTls" .!= False
      pure . Just $ RabbitMqTlsOpts caCertPath skipVerify

-- | RabbitMQ settings covering both the AMQP endpoint and the admin API.
-- 'demoteOpts' drops the admin part and yields an 'AmqpEndpoint'.
data RabbitMqAdminOpts = RabbitMqAdminOpts
  { -- | Host of the AMQP endpoint.
    host :: !String,
    -- | Port of the AMQP endpoint.
    port :: !Int,
    -- | RabbitMQ virtual host.
    vHost :: !Text,
    -- | TLS options; 'Nothing' means TLS is not used.
    tls :: Maybe RabbitMqTlsOpts,
    -- | Host used in the base URL of the admin API client.
    adminHost :: !String,
    -- | Port used in the base URL of the admin API client.
    adminPort :: !Int
  }
  deriving (Eq, Show)

-- | Reads the required keys @host@, @port@, @vHost@, @adminHost@ and
-- @adminPort@; the TLS part is delegated to 'parseTlsJson'.
instance FromJSON RabbitMqAdminOpts where
  parseJSON = withObject "RabbitMqAdminOpts" $ \obj -> do
    amqpHost <- obj .: "host"
    amqpPort <- obj .: "port"
    virtualHost <- obj .: "vHost"
    tlsOpts <- parseTlsJson obj
    managementHost <- obj .: "adminHost"
    managementPort <- obj .: "adminPort"
    pure $
      RabbitMqAdminOpts
        amqpHost
        amqpPort
        virtualHost
        tlsOpts
        managementHost
        managementPort

-- | Build a client for the RabbitMQ admin API that authenticates with the
-- given username and password using HTTP basic auth.
--
-- The client talks to @adminHost:adminPort@. When TLS options are configured
-- it uses HTTPS with settings from 'mkTLSSettings', otherwise plain HTTP.
-- Failed requests are rethrown in 'IO' via 'throwM'.
mkRabbitMqAdminClientEnvWithCreds :: RabbitMqAdminOpts -> Text -> Text -> IO (AdminAPI (AsClientT IO))
mkRabbitMqAdminClientEnvWithCreds opts username password = do
  -- The TLS settings must describe the server this client connects to, which
  -- is the admin host used in the base URL below, not the AMQP host.
  mTlsSettings <- traverse (mkTLSSettings opts.adminHost) opts.tls
  let (protocol, managerSettings) = case mTlsSettings of
        Nothing -> (Servant.Http, HTTP.defaultManagerSettings)
        Just tlsSettings -> (Servant.Https, HTTP.mkManagerSettings tlsSettings Nothing)
  manager <- HTTP.newManager managerSettings
  let basicAuthData = Servant.BasicAuthData (Text.encodeUtf8 username) (Text.encodeUtf8 password)
      clientEnv = Servant.mkClientEnv manager (Servant.BaseUrl protocol opts.adminHost opts.adminPort "")
  pure . fromServant $
    hoistClient
      (Proxy @(ToServant AdminAPI AsApi))
      -- Run each request against 'clientEnv' and turn a 'Left' into an
      -- exception so that callers get plain 'IO' actions.
      (either throwM pure <=< flip runClientM clientEnv)
      (toServant $ adminClient basicAuthData)

-- | Like 'mkRabbitMqAdminClientEnvWithCreds', with the username and password
-- obtained from 'readCredsFromEnv'.
mkRabbitMqAdminClientEnv :: RabbitMqAdminOpts -> IO (AdminAPI (AsClientT IO))
mkRabbitMqAdminClientEnv opts = do
  (username, password) <- readCredsFromEnv
  mkRabbitMqAdminClientEnvWithCreds opts username password

-- | Connection details of a RabbitMQ AMQP endpoint, without any admin API
-- settings. When admin opts are needed use 'RabbitMqAdminOpts'; 'demoteOpts'
-- converts those into an 'AmqpEndpoint'.
data AmqpEndpoint = AmqpEndpoint
  { -- | Host of the AMQP endpoint.
    host :: !String,
    -- | Port of the AMQP endpoint.
    port :: !Int,
    -- | RabbitMQ virtual host.
    vHost :: !Text,
    -- | TLS options; 'Nothing' means TLS is not used.
    tls :: !(Maybe RabbitMqTlsOpts)
  }
  deriving (Eq, Show)

-- | Reads the required keys @host@, @port@ and @vHost@; the TLS part is
-- delegated to 'parseTlsJson'.
instance FromJSON AmqpEndpoint where
  parseJSON = withObject "AmqpEndpoint" $ \obj -> do
    amqpHost <- obj .: "host"
    amqpPort <- obj .: "port"
    virtualHost <- obj .: "vHost"
    tlsOpts <- parseTlsJson obj
    pure $ AmqpEndpoint amqpHost amqpPort virtualHost tlsOpts

-- | Forget the admin API settings, keeping only the AMQP endpoint details.
demoteOpts :: RabbitMqAdminOpts -> AmqpEndpoint
demoteOpts adminOpts =
  AmqpEndpoint
    { host = adminOpts.host,
      port = adminOpts.port,
      vHost = adminOpts.vHost,
      tls = adminOpts.tls
    }

-- | Useful if the application only pushes into some queues.
--
-- Starts a background thread running 'openConnectionWithRetries' and returns
-- an 'MVar' holding the current channel. The 'MVar' is emptied on a channel
-- exception or when the connection closes, and filled again with each new
-- channel.
--
-- Blocks until the first channel is available. Throws
-- 'RabbitMqConnectionFailed' if the connection thread finishes before that;
-- an exception thrown by the connection thread is rethrown.
mkRabbitMqChannelMVar :: Logger -> Maybe Text -> AmqpEndpoint -> IO (MVar Q.Channel)
mkRabbitMqChannelMVar l connName opts = do
  chanMVar <- newEmptyMVar
  connThread <-
    async . openConnectionWithRetries l opts connName $
      RabbitMqHooks
        { -- Publish the channel and then park forever: the connection is
          -- closed as soon as this hook returns.
          onNewChannel = \chan -> putMVar chanMVar chan >> forever (threadDelay maxBound),
          onChannelException = \_ -> void $ tryTakeMVar chanMVar,
          onConnectionClose = void $ tryTakeMVar chanMVar
        }
  -- 'readMVar' waits for the first channel without taking it out of the
  -- 'MVar'. Taking and putting it back would leave a window in which the
  -- hooks' 'tryTakeMVar' finds the 'MVar' empty and a stale channel is put
  -- back afterwards. 'withAsync' makes sure the waiting thread is cancelled
  -- when the connection thread finishes or throws first.
  withAsync (void $ readMVar chanMVar) $ \waitForChannel ->
    waitEither connThread waitForChannel >>= \case
      Left () -> throwIO $ RabbitMqConnectionFailed "connection thread finished before getting connection"
      Right () -> pure chanMVar

-- | Thrown by 'mkRabbitMqChannelMVar' when no channel could be obtained. The
-- 'String' is a human readable description of the failure.
data RabbitMqConnectionError = RabbitMqConnectionFailed String
  deriving (Show)

instance Exception RabbitMqConnectionError

-- | Assemble AMQP connection options for the given endpoint.
--
-- TLS settings are built with 'mkTLSSettings' when the endpoint has TLS
-- options, and the credentials come from 'readCredsFromEnv'. The optional
-- 'Text' becomes the connection name.
mkConnectionOpts :: (MonadIO m) => AmqpEndpoint -> Maybe Text -> m Q.ConnectionOpts
mkConnectionOpts endpoint name = liftIO $ do
  customTls <- traverse (mkTLSSettings endpoint.host) endpoint.tls
  (user, pass) <- readCredsFromEnv
  pure $
    Q.defaultConnectionOpts
      { Q.coServers = [(endpoint.host, fromIntegral endpoint.port)],
        Q.coVHost = endpoint.vHost,
        Q.coAuth = [Q.plain user pass],
        Q.coTLSSettings = Q.TLSCustom <$> customTls,
        Q.coName = name
      }

-- | Connects to RabbitMQ and opens a channel. If the channel is closed for
-- some reason, reopens the channel. If the connection is closed for some
-- reason, keeps retrying to connect until it works.
--
-- The given 'RabbitMqHooks' are called for every new channel, for channel
-- exceptions and when the connection closes. The 'onNewChannel' hook runs
-- inside the 'bracket' that owns the connection, so the connection is closed
-- once that hook returns or throws.
openConnectionWithRetries ::
  forall m.
  (MonadIO m, MonadMask m, MonadBaseControl IO m) =>
  Logger ->
  AmqpEndpoint ->
  Maybe Text ->
  RabbitMqHooks m ->
  m ()
openConnectionWithRetries l AmqpEndpoint {..} connName hooks = do
  -- NOTE(review): the credentials read here are only threaded through
  -- 'connectWithRetries'; 'mkConnectionOpts' reads them from the environment
  -- again for every connection attempt. Presumably this read exists to fail
  -- early when the variables are missing -- confirm.
  (username, password) <- liftIO $ readCredsFromEnv
  connectWithRetries username password
  where
    connectWithRetries :: Text -> Text -> m ()
    connectWithRetries username password = do
      -- Jittered exponential backoff with 1ms as starting delay and 5s as max
      -- delay.
      let policy = capDelay 5_000_000 $ fullJitterBackoff 1000
          -- Logs a failed connection attempt together with whether it will be
          -- retried and the current retry iteration.
          logError willRetry e retryStatus = do
            Log.err l $
              Log.msg (Log.val "Failed to connect to RabbitMQ")
                . Log.field "error" (displayException @SomeException e)
                . Log.field "willRetry" willRetry
                . Log.field "retryCount" retryStatus.rsIterNumber
          -- Keeps trying to open a connection according to 'policy'.
          getConn = do
            Log.info l $ Log.msg (Log.val "About to enter recovering...")
            conn <-
              recovering
                policy
                -- Async exceptions are logged and not retried; every other
                -- exception is logged via 'logError' and retried.
                ( logAndSkipAsyncExceptions l
                    <> [logRetries (const $ pure True) logError]
                )
                ( const $ do
                    Log.info l $ Log.msg (Log.val "Trying to connect to RabbitMQ")
                    connOpts <- mkConnectionOpts AmqpEndpoint {..} connName
                    liftIO $ Q.openConnection'' connOpts
                )
            Log.info l $ Log.msg (Log.val "Retrieved connection...")
            pure conn
      -- The connection is closed when the body below exits, normally or with
      -- an exception.
      bracket getConn (liftIO . Q.closeConnection) $ \conn -> do
        liftBaseWith $ \runInIO ->
          -- When the connection closes: run the hook (logging anything it
          -- throws) and then start over with a fresh connection.
          -- NOTE(review): the 'True' flag presumably makes the handler fire
          -- immediately if the connection is already closed -- confirm
          -- against the amqp documentation.
          Q.addConnectionClosedHandler conn True $ void $ runInIO $ do
            hooks.onConnectionClose
              `catch` logException l "onConnectionClose hook threw an exception, reconnecting to RabbitMQ anyway"
            connectWithRetries username password
        openChan conn

    -- Opens a channel on the connection, registers 'chanExceptionHandler' for
    -- it and hands the channel to the 'onNewChannel' hook.
    openChan :: Q.Connection -> m ()
    openChan conn = do
      Log.info l $ Log.msg (Log.val "Opening channel with RabbitMQ")
      chan <- liftIO $ Q.openChannel conn
      liftBaseWith $ \runInIO ->
        Q.addChannelExceptionHandler chan (void . runInIO . chanExceptionHandler conn)
      Log.info l $ Log.msg (Log.val "RabbitMQ channel opened")
      hooks.onNewChannel chan

    -- Runs the 'onChannelException' hook (logging anything it throws) and then
    -- reopens the channel, unless the channel was closed normally or the
    -- exception says that the connection itself is closed.
    chanExceptionHandler :: Q.Connection -> SomeException -> m ()
    chanExceptionHandler conn e = do
      hooks.onChannelException e `catch` logException l "onChannelException hook threw an exception"
      case (Q.isNormalChannelClose e, fromException e) of
        (True, _) ->
          Log.info l $
            Log.msg (Log.val "RabbitMQ channel is closed normally, not attempting to reopen channel")
        (_, Just (Q.ConnectionClosedException {})) ->
          Log.info l $
            Log.msg (Log.val "RabbitMQ connection is closed, not attempting to reopen channel")
        _ -> do
          logException l "RabbitMQ channel closed" e
          openChan conn

-- | Retry handlers for 'AsyncException' and 'SomeAsyncException': each one
-- logs the exception and answers 'False', i.e. the action is not retried.
-- See also `Control.Retry.skipAsyncExceptions`
logAndSkipAsyncExceptions :: (MonadIO m) => Logger -> [RetryStatus -> Control.Monad.Catch.Handler m Bool]
logAndSkipAsyncExceptions l =
  [ const . Handler $ \(e :: AsyncException) -> giveUp "AsyncException caught" e,
    const . Handler $ \(e :: SomeAsyncException) -> giveUp "SomeAsyncException caught" e
  ]
  where
    -- Log the exception under the given message and decline to retry.
    giveUp :: (MonadIO n, Exception e) => String -> e -> n Bool
    giveUp note e = False <$ logException l note (SomeException e)

-- | Build TLS client settings for the given host from 'RabbitMqTlsOpts'.
--
-- * If 'caCert' is set, the certificate store read from that path becomes the
--   CA store. If the store cannot be read, the default shared settings are
--   used without reporting an error.
-- * If 'insecureSkipVerifyTls' is set, the server certificate hook reports no
--   validation errors.
-- * Only TLS 1.3 and TLS 1.2 with the strong cipher suites are offered.
mkTLSSettings :: HostName -> RabbitMqTlsOpts -> IO TLSSettings
mkTLSSettings host opts = do
  mCaStore <- maybe (pure Nothing) X509.readCertificateStore opts.caCert
  let shared = case mCaStore of
        Just store -> def {sharedCAStore = store}
        Nothing -> def
      hooks
        | opts.insecureSkipVerifyTls = def {onServerCertificate = \_ _ _ _ -> pure []}
        | otherwise = def
      supported =
        def
          { supportedVersions = [TLS13, TLS12],
            supportedCiphers = ciphersuite_strong
          }
      params =
        (defaultParamsClient host "rabbitmq")
          { clientShared = shared,
            clientHooks = hooks,
            clientSupported = supported
          }
  pure $ TLSSettings params

-- | Log the given message at error level, with the rendered exception in the
-- @error@ field.
logException :: (MonadIO m) => Logger -> String -> SomeException -> m ()
logException l m (SomeException inner) = Log.err l details
  where
    details = Log.msg m . Log.field "error" (displayException inner)

-- | Read the RabbitMQ username and password, in that order, from the
-- environment variables @RABBITMQ_USERNAME@ and @RABBITMQ_PASSWORD@.
readCredsFromEnv :: IO (Text, Text)
readCredsFromEnv = do
  username <- lookupVar "RABBITMQ_USERNAME"
  password <- lookupVar "RABBITMQ_PASSWORD"
  pure (username, password)
  where
    lookupVar = fmap Text.pack . getEnv
