import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.pipeline import make_pipeline

from aeon.regression import BaseRegressor
from aeon.transformations.collection.convolution_based._hydra import HydraTransformer


class HydraRegressor(BaseRegressor):
    """Hydra Regressor.

    Hydra organises convolutional kernels into ``g`` groups per dilation, each
    holding ``k`` kernels. Every input series is convolved with these kernels and,
    at each time point, the kernel giving the closest match to the input is
    counted. The counts of all groups are concatenated and used to train a linear
    regressor.

    The approach blends ideas from Rocket (convolutional kernels) and from
    traditional dictionary methods (pattern counting). Features are extracted from
    both the base series and the first-order differences of the series.

    Parameters
    ----------
    n_kernels : int, default=8
        Number of kernels per group.
    n_groups : int, default=64
        Number of groups per dilation.
    n_jobs : int, default=1
        The number of jobs to run in parallel for both `fit` and `predict`.
        ``-1`` means using all processors.
    random_state : int, RandomState instance or None, default=None
        If `int`, random_state is the seed used by the random number generator;
        If `RandomState` instance, random_state is the random number generator;
        If `None`, the random number generator is the `RandomState` instance used
        by `np.random`.

    See Also
    --------
    HydraTransformer
    MultiRocketHydraRegressor

    Notes
    -----
    Original code: https://github.com/angus924/hydra

    References
    ----------
    .. [1] Dempster, A., Schmidt, D.F. and Webb, G.I., 2023. Hydra: Competing
        convolutional kernels for fast and accurate time series classification.
        Data Mining and Knowledge Discovery, pp.1-27.

    Examples
    --------
    >>> from aeon.regression.convolution_based import HydraRegressor
    >>> from aeon.testing.utils.data_gen import make_example_3d_numpy
    >>> X, y = make_example_3d_numpy(n_cases=10, n_channels=1, n_timepoints=12,
    ...                              regression_target=True, random_state=0)
    >>> reg = HydraRegressor(random_state=0)  # doctest: +SKIP
    >>> reg.fit(X, y)  # doctest: +SKIP
    HydraRegressor(random_state=0)
    >>> y_pred = reg.predict(X)  # doctest: +SKIP
    """

    _tags = {
        "capability:multivariate": True,
        "capability:multithreading": True,
        "algorithm_type": "convolution",
        "python_dependencies": "torch",
    }

    def __init__(self, n_kernels=8, n_groups=64, n_jobs=1, random_state=None):
        # Hyper-parameters are stored unmodified under their constructor names.
        self.random_state = random_state
        self.n_jobs = n_jobs
        self.n_groups = n_groups
        self.n_kernels = n_kernels

        super().__init__()

    def _fit(self, X, y):
        """Fit the Hydra transform, sparse scaler and ridge regressor to ``X``, ``y``.

        Parameters
        ----------
        X : collection of time series
            Training series, passed unchanged to the internal pipeline.
        y : array-like
            Target values, passed unchanged to the internal pipeline.

        Returns
        -------
        self
            Reference to the fitted estimator.
        """
        hydra_params = {
            "n_kernels": self.n_kernels,
            "n_groups": self.n_groups,
            "n_jobs": self.n_jobs,
            "random_state": self.random_state,
        }
        # Ridge penalties searched by cross-validation: 10 log-spaced values
        # from 1e-3 to 1e3.
        ridge_alphas = np.logspace(-3, 3, 10)

        # Pipeline order: Hydra features -> sparse scaling -> RidgeCV.
        self._clf = make_pipeline(
            HydraTransformer(**hydra_params),
            _SparseScaler(),
            RidgeCV(alphas=ridge_alphas),
        )
        self._clf.fit(X, y)

        return self

    def _predict(self, X) -> np.ndarray:
        """Predict target values for ``X`` using the fitted pipeline.

        Parameters
        ----------
        X : collection of time series
            Series to predict, passed unchanged to the internal pipeline.

        Returns
        -------
        np.ndarray
            Output of the fitted pipeline's ``predict``.
        """
        predictions = self._clf.predict(X)
        return predictions


class _SparseScaler:
    """Sparse Scaler for hydra transform."""

    def __init__(self, mask=True, exponent=4):
        self.mask = mask
        self.exponent = exponent

    def fit(self, X, y=None):
        X = X.clamp(0).sqrt()

        self.epsilon = (X == 0).float().mean(0) ** self.exponent + 1e-8

        self.mu = X.mean(0)
        self.sigma = X.std(0) + self.epsilon

    def transform(self, X, y=None):
        X = X.clamp(0).sqrt()

        if self.mask:
            return ((X - self.mu) * (X != 0)) / self.sigma
        else:
            return (X - self.mu) / self.sigma

    def fit_transform(self, X, y=None):
        self.fit(X)
        return self.transform(X)
