"""Implements transformers raise time series to user provided exponent."""

# No maintainers are currently listed for this module.
__maintainer__ = []
# Names exported by a star-import of this module.
__all__ = ["ExponentTransformer"]

from warnings import warn

import numpy as np
import pandas as pd
from deprecated.sphinx import deprecated

from aeon.transformations.base import BaseTransformer


# TODO: remove in v0.11.0
@deprecated(
    version="0.10.0",
    reason="ExponentTransformer will be removed in version 0.11.0.",
    category=FutureWarning,
)
class ExponentTransformer(BaseTransformer):
    """Apply element-wise exponentiation transformation to a time series.

    Transformation performs the following operations element-wise:
        * adds the constant `offset` (shift)
        * raises to the `power` provided (exponentiation)
    Offset="auto" computes offset as the smallest offset that ensures all elements
    are non-negative before exponentiation.

    Parameters
    ----------
    power : int or float, default=0.5
        The power to raise the input timeseries to.

    offset : "auto", int, float, pd.Series or np.ndarray, default="auto"
        Offset to be added to the input timeseries prior to raising
        the timeseries to the given `power`. If "auto" the series is checked to
        determine if it contains negative values. If negative values are found
        then the offset will be equal to the absolute value of the most negative
        value. If no negative values are present the offset is set to zero.
        If an integer or float value is supplied it will be used as the offset.
        A pd.Series or np.ndarray supplies one offset per column of X; the values
        are matched to the columns by position.

    Attributes
    ----------
    power : int or float
        User supplied power.

    offset : int or float, or iterable.
        User supplied offset value.
        Scalar or 1D iterable with as many values as X columns in transform.


    Notes
    -----
    For an input series `Z` the exponent transformation is defined as
    :math:`(Z + offset)^{power}`.

    The offset is not stored by `transform`. `inverse_transform` re-derives it from
    the data it is given, so with ``offset="auto"`` the offset used for the inverse
    is computed from the transformed values, not from the original series.

    Examples
    --------
    >>> from aeon.transformations.exponent import ExponentTransformer
    >>> from aeon.datasets import load_airline
    >>> y = load_airline()
    >>> transformer = ExponentTransformer()
    >>> y_transform = transformer.fit_transform(y)
    """

    _tags = {
        "input_data_type": "Series",
        # what is the abstract type of X: Series, or Panel
        "output_data_type": "Series",
        # what abstract type is returned: Primitives, Series, Panel
        "instancewise": True,  # is this an instance-wise transform?
        "X_inner_type": ["pd.DataFrame", "pd.Series"],
        "y_inner_type": "None",
        "fit_is_empty": True,
        "transform-returns-same-time-index": True,
        "capability:multivariate": True,
        "capability:inverse_transform": True,
    }

    def __init__(self, power=0.5, offset="auto"):
        self.power = power
        self.offset = offset

        if not isinstance(self.power, (int, float)):
            raise ValueError(
                f"Expected `power` to be int or float, but found {type(self.power)}."
            )

        offset_types = (int, float, pd.Series, np.ndarray)
        if not isinstance(offset, offset_types) and not self._is_auto(offset):
            raise ValueError(
                f"Expected `offset` to be int or float, but found {type(self.offset)}."
            )

        # Tags can only be modified once the base class has been initialised.
        super().__init__()

        if abs(power) < 1e-6:
            # The inverse needs 1 / power, so it is disabled for near-zero powers.
            warn(
                "power close to zero passed to ExponentTransformer, "
                "inverse_transform will default to identity "
                "if called, in order to avoid division by zero",
                stacklevel=2,
            )
            self.set_tags(**{"skip-inverse-transform": True})

    @staticmethod
    def _is_auto(offset):
        """Return True only if `offset` is the string "auto".

        The type guard matters: comparing a pd.Series or np.ndarray with a string
        is element-wise, and the result cannot be used as a single boolean.
        """
        return isinstance(offset, str) and offset == "auto"

    def _transform(self, X, y=None):
        """Transform X and return a transformed version.

        private _transform containing the core logic, called from transform

        Parameters
        ----------
        X : pd.Series or pd.DataFrame
            Data to be transformed
        y : ignored argument for interface compatibility
            Additional data, e.g., labels for transformation

        Returns
        -------
        Xt : pd.Series or pd.DataFrame, same type as X
            transformed version of X
        """
        offset = self._get_offset(X)
        Xt = X.add(offset).pow(self.power)
        return Xt

    def _inverse_transform(self, X, y=None):
        """Logic used by `inverse_transform` to reverse transformation on `X`.

        The offset is re-derived from the `X` passed here. With ``offset="auto"``
        that means it is computed from the transformed data.

        Parameters
        ----------
        X : pd.Series or pd.DataFrame
            Data to be inverse transformed
        y : ignored argument for interface compatibility
            Additional data, e.g., labels for transformation

        Returns
        -------
        Xt : pd.Series or pd.DataFrame, same type as X
            inverse transformed version of X
        """
        offset = self._get_offset(X)
        Xt = X.pow(1.0 / self.power).add(-offset)
        return Xt

    def _get_offset(self, X):
        """Resolve the offset to add to `X`.

        Parameters
        ----------
        X : pd.Series or pd.DataFrame
            Data the offset will be applied to.

        Returns
        -------
        offset : scalar, pd.Series or np.ndarray
            For a pd.DataFrame `X`, a pd.Series indexed by `X.columns` holding
            one offset per column. Otherwise the user supplied offset, or for
            "auto" the value computed from `X.min()`.

        Raises
        ------
        ValueError
            If `X` is a pd.DataFrame and a non-scalar offset does not have one
            value per column.
        """
        if self._is_auto(self.offset):
            # Shift by |min| only where the minimum is negative, otherwise by zero.
            Xmin = X.min()
            offset = -Xmin * (Xmin < 0)
        else:
            offset = self.offset

        if isinstance(X, pd.DataFrame):
            if isinstance(offset, (int, float)):
                offset = pd.Series(offset, index=X.columns)
            else:
                # Build a fresh Series from the raw values so a user supplied
                # Series is never relabelled in place; values map to columns
                # by position.
                offset = pd.Series(np.asarray(offset), index=X.columns)

        return offset

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return `"default"` set.
            There are currently no reserved values for transformers.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`
        """
        return [{"power": 2.5, "offset": 1}]
