"""Implement transformers for summarizing a time series."""

__maintainer__ = []
# Public API of this module.
# NOTE(review): PlateauFinder and FittedParamExtractor are also defined in this
# module but are not listed here, so a star-import does not export them --
# confirm this is intentional.
__all__ = ["SummaryTransformer", "WindowSummarizer"]

import numpy as np
import pandas as pd
from joblib import Parallel, delayed

from aeon.transformations.base import BaseTransformer
from aeon.utils.multiindex import flatten_multiindex


class WindowSummarizer(BaseTransformer):
    """
    Transformer for extracting time series features.

    The WindowSummarizer transforms input series to features based
    on a provided dictionary of window summarizers, window shifts
    and window lengths. ``transform`` returns a ``pd.DataFrame`` containing all
    transformed columns as well as the non-transformed columns of the input; the
    raw input columns that were transformed (``target_cols``) are dropped.

    Parameters
    ----------
    n_jobs : int, default=-1
        The number of jobs to run in parallel for applying the window functions.
        ``-1`` means using all processors.
    target_cols : str or list of str, optional (default = None)
        Specifies which columns in X to target for applying the window functions.
        ``None`` will target the first column. A single column name may be given
        as a string. Names are matched against the string form of X's columns.
    lag_feature : dict of str and list, optional (default = dict containing first lag)
        Dictionary specifying as key the type of function to be used and as value
        the argument `window`.
        For the function `lag`, the argument `window` is an integer or a list of
        integers giving the `lag` values to be used.
        For all other functions, the argument `window` is a list with the arguments
        `lag` and `window length`. `lag` defines how far back in the past the window
        starts, `window length` gives the length of the window across which to apply
        the function. For multiple different windows, provide a list of lists.

        Please see below a graphical representation of the logic using the following
        symbols:

        ``z`` = time stamp that the window is summarized *to*.
        Part of the window if `lag` is between 0 and `1-window_length`, otherwise
        not part of the window.
        ``x`` = (other) time stamps in the window which is summarized
        ``*`` = observations, past or future, not part of the window

        The summarization function is applied to the window consisting of x and
        potentially z.

        For `window = [1, 3]`, we have a `lag` of 1 and
        `window_length` of 3 to target the three last days (exclusive z) that were
        observed. Summarization is done across windows like this:
        |---------------------------|
        | * * * * * * * * x x x z * |
        |---------------------------|

        For `window = [0, 3]`, we have a `lag` of 0 and
        `window_length` of 3 to target the three last days (inclusive z) that
        were observed. Summarization is done across windows like this:
        |---------------------------|
        | * * * * * * * * x x z * * |
        |---------------------------|

        Special case `lag`: Since lags are frequently used and window length is
        redundant, you only need to provide a list of `lag` values.
        So `window = [1]` will result in the first lag:

        |---------------------------|
        | * * * * * * * * * * x z * |
        |---------------------------|

        And `window = [1, 4]` will result in the first and fourth lag:

        |---------------------------|
        | * * * * * * * x * * x z * |
        |---------------------------|

        key: either a custom function (to be provided by the user) or a str
            corresponding to a native pandas window function:
                * "sum",
                * "mean",
                * "median",
                * "std",
                * "var",
                * "kurt",
                * "min",
                * "max",
                * "corr",
                * "cov",
                * "skew",
                * "sem"
            See also: https://pandas.pydata.org/docs/reference/window.html.
            The column generated will be named after the key provided (for a
            custom function, its ``__name__``), followed by the lag parameter and
            the window_length (if not a lag).
        value (window): list of integers
            List containing lag and window_length parameters.
    truncate : str, optional (default = None)
        Defines how to deal with NAs that were created as a result of applying the
        functions in the lag_feature dict across windows that are longer than
        the remaining history of data.
        For example a lag config of [14, 7] cannot be fully applied for the first 20
        observations of the targeted column.
        A lag_feature of [[8, 14], [1, 28]] cannot be correctly applied for the
        first 21 resp. 28 observations of the targeted column. Possible values
        to deal with those NAs:
            * None
            * "bfill"
        None will keep the NAs generated, and would leave it for the user to choose
        an estimator that can correctly deal with observations with missing values,
        "bfill" will fill the NAs by carrying the first observation backwards.
        Any other value raises a ValueError when fitting.

    Attributes
    ----------
    truncate_start : int
        See section Parameters - truncate for a more detailed explanation of truncation
        as a result of applying windows of certain lengths across past observations.
        Truncate_start will give the maximum of observations that are filled with NAs
        across all arguments of the lag_feature when truncate is set to None.
    """

    _tags = {
        "input_data_type": "Series",
        "output_data_type": "Series",
        "instancewise": True,
        "capability:inverse_transform": False,
        "transform_labels": False,
        "X_inner_type": [
            "pd-multiindex",
            "pd.DataFrame",
            "pd_multiindex_hier",
        ],
        "skip-inverse-transform": True,  # is inverse-transform skipped when called?
        "capability:multivariate": True,  # can the transformer handle multivariate X?
        "capability:missing_values": True,  # can estimator handle missing data?
        "X-y-must-have-same-index": False,  # can estimator handle different X/y index?
        "enforce_index_type": None,  # index type that needs to be enforced in X/y
        "fit_is_empty": False,  # is fit empty and can be skipped? Yes = True
        "transform-returns-same-time-index": False,
        # does transform return have the same time index as input X
        "remember_data": True,  # remember all data seen as _X
    }

    def __init__(
        self,
        lag_feature=None,
        n_jobs=-1,
        target_cols=None,
        truncate=None,
    ):
        self.lag_feature = lag_feature
        self.n_jobs = n_jobs
        self.target_cols = target_cols
        self.truncate = truncate

        super().__init__()

    def _fit(self, X, y=None):
        """Fit transformer to X and y.

        Private _fit containing the core logic, called from fit. Validates
        ``truncate`` and ``target_cols`` and converts ``lag_feature`` into the
        table of (summarizer, window) pairs used by ``_transform``. Nothing is
        returned; the results are stored on the instance.

        Parameters
        ----------
        X : pd.DataFrame
            Data to fit the transformer to.
        y : ignored argument for interface compatibility

        Attributes
        ----------
        truncate_start : int
            See section class WindowSummarizer - Parameters - truncate for a more
            detailed explanation of truncation as a result of applying windows of
            certain lengths across past observations.
            Truncate_start will give the maximum of observations that are filled
            with NAs across all arguments of the lag_feature when truncate is
            set to None.

        Raises
        ------
        ValueError
            If ``truncate`` is neither None nor "bfill", or if ``target_cols``
            contains names that do not exist in X.
        """
        if not (self.truncate is None or self.truncate == "bfill"):
            raise ValueError(
                f"truncate must be None or 'bfill', but found: {self.truncate!r}"
            )

        X_name = get_name_list(X)

        target_cols = self.target_cols
        if isinstance(target_cols, str):
            # a single column name, not an iterable of one-character names
            target_cols = [target_cols]

        if target_cols is None:
            self._target_cols = [X_name[0]]
        else:
            # _transform works on str column labels, so compare in str form
            target_cols = [str(col) for col in target_cols]
            missing_cols = [col for col in target_cols if col not in X_name]
            if missing_cols:
                raise ValueError(
                    "target_cols "
                    + " ".join(missing_cols)
                    + " specified that do not exist in X."
                )
            self._target_cols = target_cols

        # Convert lag config dictionary to pandas dataframe
        if self.lag_feature is None:
            func_dict = pd.DataFrame(
                {
                    "lag": [1],
                }
            ).T.reset_index()
        else:
            func_dict = pd.DataFrame.from_dict(
                self.lag_feature, orient="index"
            ).reset_index()

        # one row per (summarizer, window) pair
        func_dict = pd.melt(
            func_dict, id_vars="index", value_name="window", ignore_index=False
        )
        func_dict.sort_index(inplace=True)
        func_dict.drop("variable", axis=1, inplace=True)
        func_dict.rename(
            columns={"index": "summarizer"},
            inplace=True,
        )
        # ragged lag_feature lists are padded with NaN by from_dict; drop them
        func_dict = func_dict.dropna(axis=0, how="any")
        # Identify lags (since they can follow special notation)
        lags = func_dict["summarizer"] == "lag"
        # Convert lags to default list notation with window_length 1
        boost_lag = func_dict.loc[lags, "window"].apply(lambda x: [int(x), 1])
        func_dict.loc[lags, "window"] = boost_lag
        self.truncate_start = func_dict["window"].apply(lambda x: x[0] + x[1] - 1).max()
        self._func_dict = func_dict

    def _transform(self, X, y=None):
        """Transform X and return a transformed version.

        Parameters
        ----------
        X : pd.DataFrame
        y : None

        Returns
        -------
        Xt : pd.DataFrame
            Window features of each target column (prefixed with the column name)
            followed by the non-target columns of X, restricted to X's index.
        """
        idx = X.index
        # prepend remembered data so windows can reach back before X starts
        X = X.combine_first(self._X)

        func_dict = self._func_dict
        target_cols = self._target_cols

        X.columns = X.columns.map(str)
        bfill = self.truncate == "bfill"
        is_hierarchical = isinstance(X.index, pd.MultiIndex)

        Xt_out = []
        for cols in target_cols:
            if is_hierarchical:
                # group by all but the last index level so that each window is
                # computed within a single group
                hier_levels = list(range(X.index.nlevels - 1))
                Z = X.groupby(level=hier_levels)[cols]
            else:
                Z = X.loc[:, [cols]]
            df = Parallel(n_jobs=self.n_jobs)(
                delayed(_window_feature)(Z, **kwargs, bfill=bfill)
                for _index, kwargs in func_dict.iterrows()
            )
            Xt = pd.concat(df, axis=1)
            Xt = Xt.add_prefix(str(cols) + "_")
            Xt_out.append(Xt)
        Xt_out_df = pd.concat(Xt_out, axis=1)
        Xt_return = pd.concat([Xt_out_df, X.drop(target_cols, axis=1)], axis=1)

        # drop the remembered history again, keep only the rows asked for
        Xt_return = Xt_return.loc[idx]
        return Xt_return

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return `"default"` set.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`
        """
        params1 = {
            "lag_feature": {
                "lag": [1],
                "mean": [[1, 3], [1, 12]],
                "std": [[1, 4]],
            }
        }

        params2 = {
            "lag_feature": {
                "lag": [3, 6],
            }
        }

        params3 = {
            "lag_feature": {
                "mean": [[1, 7], [8, 7]],
                "cov": [[1, 28]],
            }
        }

        return [params1, params2, params3]


# Names of native pandas rolling-window methods. ``_window_feature`` calls these
# directly on a pandas ``Rolling`` object; any other summarizer is treated as a
# lag or, if callable, applied through ``Rolling.apply``.
# In the future different engines for pandas will be investigated
pd_rolling = (
    "sum mean median std var kurt min max corr cov skew sem"
).split()


def get_name_list(Z):
    """Get the names of a pd.Series or pd.DataFrame as strings.

    Parameters
    ----------
    Z : pd.Series or pd.DataFrame

    Returns
    -------
    list of str or None
        The column labels of a DataFrame, or ``[name]`` for a named Series, each
        converted with ``str``. None for a Series without a name.
    """
    if isinstance(Z, pd.DataFrame):
        names = Z.columns.to_list()
    elif Z.name is not None:
        names = [Z.name]
    else:
        # unnamed Series: nothing to convert (iterating None would raise)
        return None
    return [str(name) for name in names]


def _window_feature(Z, summarizer=None, window=None, bfill=False):
    """Compute window features and lag.

    Apply summarizer passed over a certain window of past observations, e.g. the
    mean of a window of length 7 days, lagged by 14 days.

    When ``Z`` is grouped (hierarchical data), the shift, the optional backward
    fill and the rolling summary are all computed separately within each group,
    so no window mixes observations from two different groups.

    Parameters
    ----------
    Z : pd.DataFrame with a single column, or pandas SeriesGroupBy
        Data to summarize.
    summarizer : str or callable
        Either a str corresponding to a pandas rolling window function, currently
            * "sum",
            * "mean",
            * "median",
            * "std",
            * "var",
            * "kurt",
            * "min",
            * "max",
            * "corr",
            * "cov",
            * "skew",
            * "sem"
        or a custom function, applied with ``Rolling.apply(..., raw=True)``.
        See for the native window functions also
        https://pandas.pydata.org/docs/reference/window.html.
        Any other value (e.g. "lag") returns the shifted series itself.
    window : list of int
        ``[lag, window_length]``, see the WindowSummarizer class description for
        an in-depth explanation.
    bfill : bool, default=False
        If True, NAs are backward filled both after shifting and after
        summarizing (within each group for grouped input).

    Returns
    -------
    feat : pd.DataFrame
        Single column named ``<name>_<lag>`` for "lag" and
        ``<name>_<lag>_<window_length>`` otherwise, where ``<name>`` is the
        summarizer string or the ``__name__`` of the custom function.
    """
    lag = window[0]
    window_length = window[1]

    def _summarize(series):
        # shift, optionally backfill and summarize one un-grouped series
        shifted = series.shift(lag)
        if bfill:
            shifted = shifted.bfill()
        if summarizer in pd_rolling:
            rolling = shifted.rolling(window=window_length, min_periods=window_length)
            result = getattr(rolling, summarizer)()
        elif callable(summarizer):
            result = shifted.rolling(
                window=window_length, min_periods=window_length
            ).apply(summarizer, raw=True)
        else:
            # "lag" (or any other non-rolling key): the shifted series itself
            result = shifted
        if bfill:
            result = result.bfill()
        return result

    if isinstance(Z, pd.core.groupby.generic.SeriesGroupBy):
        # transform applies _summarize per group and realigns to Z's index
        feat = pd.DataFrame(Z.transform(_summarize))
    else:
        feat = pd.DataFrame(Z.apply(_summarize))

    name = summarizer.__name__ if callable(summarizer) else summarizer

    if name == "lag":
        col_name = name + "_" + str(window[0])
    else:
        col_name = name + "_" + "_".join(str(item) for item in window)
    feat = feat.rename(columns={feat.columns[0]: col_name})
    return feat


# Summary function names accepted by ``_check_summary_function``; they are
# passed to pandas ``agg`` by SummaryTransformer.
ALLOWED_SUM_FUNCS = (
    "mean min max median sum skew kurt var std sem nunique count"
).split()


def _check_summary_function(summary_function):
    """Validate summary_function.

    Parameters
    ----------
    summary_function : str, list, tuple or None
        Either None, a string, or a non-empty list/tuple of strings indicating the
        pandas summary functions ("mean", "min", "max", "median", "sum", "skew",
        "kurt", "var", "std", "sem", "nunique", "count") used to summarize each
        column of the dataset.

    Returns
    -------
    summary_function : list, tuple or None
        A single string is wrapped in a list; lists, tuples and None are
        returned unchanged.

    Raises
    ------
    ValueError
        If summary_function is not None, a str, or a list/tuple; if it is an
        empty list/tuple; or if any name is not in ``ALLOWED_SUM_FUNCS``.
    """
    msg = f"""`summary_function` must be None, or str or a list or tuple made up of
          {ALLOWED_SUM_FUNCS}.
          """
    if isinstance(summary_function, str):
        if summary_function not in ALLOWED_SUM_FUNCS:
            raise ValueError(msg)
        summary_function = [summary_function]
    elif isinstance(summary_function, (list, tuple)):
        # an empty selection would only fail later inside pandas ``agg``
        if len(summary_function) == 0 or not all(
            func in ALLOWED_SUM_FUNCS for func in summary_function
        ):
            raise ValueError(msg)
    elif summary_function is not None:
        raise ValueError(msg)
    return summary_function


def _check_quantiles(quantiles):
    """Validate quantiles.

    Parameters
    ----------
    quantiles : int, float, list, tuple or None
        Either None, a single number, or a non-empty list/tuple of numbers, each
        between 0 and 1 inclusive, giving the quantiles to compute for each
        column of the dataset. Booleans are rejected even though ``bool`` is a
        subclass of ``int``.

    Returns
    -------
    quantiles : list, tuple or None
        A single number is wrapped in a list; lists, tuples and None are
        returned unchanged.

    Raises
    ------
    ValueError
        If quantiles is not None, a number, or a non-empty list/tuple of numbers
        in [0, 1].
    """
    msg = """`quantiles` must be None, int, float or a list or tuple made up of
          int and float values that are between 0 and 1.
          """

    def _is_valid(q):
        # True/False are ints to isinstance, but are not meaningful quantiles;
        # NaN fails the range check because all NaN comparisons are False
        return (
            isinstance(q, (int, float))
            and not isinstance(q, bool)
            and 0.0 <= q <= 1.0
        )

    if isinstance(quantiles, (list, tuple)):
        if len(quantiles) == 0 or not all(_is_valid(q) for q in quantiles):
            raise ValueError(msg)
    elif quantiles is not None:
        if not _is_valid(quantiles):
            raise ValueError(msg)
        quantiles = [quantiles]
    return quantiles


class SummaryTransformer(BaseTransformer):
    """Calculate summary value of a time series.

    For :term:`univariate time series` a combination of summary functions and
    quantiles of the input series are calculated. If the input is a
    :term:`multivariate time series` then the summary functions and quantiles
    are calculated separately for each column.

    Parameters
    ----------
    summary_function : str, list, tuple, or None, default=("mean", "std", "min", "max")
        If not None, a string, or list or tuple of strings indicating the pandas
        summary functions that are used to summarize each column of the dataset.
        Must be one of ("mean", "min", "max", "median", "sum", "skew", "kurt",
        "var", "std", "sem", "nunique", "count").
        If None, no summaries are calculated, and quantiles must be non-None.
    quantiles : int, float, list, tuple or None, default=(0.1, 0.25, 0.5, 0.75, 0.9)
        Optional quantile, or list or tuple of quantiles, to calculate; each must
        be a number between 0 and 1. If None, no quantiles are calculated, and
        summary_function must be non-None.
    flatten_transform_index : bool, default=True
        if True, columns of return DataFrame are flat, by "variablename__feature"
        if False, columns are MultiIndex (variablename__feature)
        has no effect if return type is one without column names

    See Also
    --------
    WindowSummarizer:
        Extracting features across (shifted) windows from series

    Notes
    -----
    This provides a wrapper around pandas DataFrame and Series agg and
    quantile methods. ``summary_function`` and ``quantiles`` are validated when
    ``transform`` is called.
    """

    _tags = {
        "input_data_type": "Series",
        # what is the abstract type of X: Series, or Panel
        "output_data_type": "Primitives",
        # what abstract type is returned: Primitives, Series, Panel
        "instancewise": True,  # is this an instance-wise transform?
        "X_inner_type": ["pd.DataFrame", "pd.Series"],
        "y_inner_type": "None",
        "fit_is_empty": True,
    }

    def __init__(
        self,
        summary_function=("mean", "std", "min", "max"),
        quantiles=(0.1, 0.25, 0.5, 0.75, 0.9),
        flatten_transform_index=True,
    ):
        self.summary_function = summary_function
        self.quantiles = quantiles
        self.flatten_transform_index = flatten_transform_index

        super().__init__()

    def _transform(self, X, y=None):
        """Transform X and return a transformed version.

        private _transform containing the core logic, called from transform

        Parameters
        ----------
        X : pd.Series or pd.DataFrame
            Data to be transformed
        y : ignored argument for interface compatibility
            Additional data, e.g., labels for transformation

        Returns
        -------
        Xt : pd.DataFrame
            A single row holding the requested summary statistics followed by the
            quantiles (quantile labels converted to str). For a multivariate input
            the columns combine variable name and feature, and are flattened when
            ``flatten_transform_index`` is True.

        Raises
        ------
        ValueError
            If both ``summary_function`` and ``quantiles`` are None, or if either
            fails validation.
        """
        if self.summary_function is None and self.quantiles is None:
            raise ValueError(
                "One of `summary_function` and `quantiles` must not be None."
            )
        summary_function = _check_summary_function(self.summary_function)
        quantiles = _check_quantiles(self.quantiles)

        # collect the requested pieces: summaries first, then quantiles
        parts = []
        if summary_function is not None:
            parts.append(X.agg(summary_function))
        if quantiles is not None:
            quantile_value = X.quantile(quantiles)
            quantile_value.index = [str(s) for s in quantile_value.index]
            parts.append(quantile_value)
        summary_value = parts[0] if len(parts) == 1 else pd.concat(parts)

        if isinstance(X, pd.Series):
            summary_value.name = X.name
            summary_value = pd.DataFrame(summary_value)

        # one row per input variable, one column per feature
        Xt = summary_value.T

        if len(Xt) > 1:
            # multivariate: move the variable (row) index into a second column
            # level so the result collapses to a single row
            Xt = pd.DataFrame(Xt.T.unstack()).T
            if self.flatten_transform_index:
                Xt.columns = flatten_multiindex(Xt.columns)

        return Xt

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return `"default"` set.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`
        """
        params1 = {}
        params2 = {"summary_function": ["mean", "std", "skew"], "quantiles": None}
        params3 = {"summary_function": None, "quantiles": (0.1, 0.2, 0.25)}

        return [params1, params2, params3]


class PlateauFinder(BaseTransformer):
    """
    Plateau finder transformer.

    Transformer that finds segments of the same given value, plateau in
    the time series, and returns the starting indices and lengths.

    Parameters
    ----------
    value : {int, float, np.nan, np.inf}, default=np.nan
        Value for which to find segments. NaN is matched with ``np.isnan``; every
        other value, including ``np.inf`` and ``-np.inf``, is matched by equality,
        so ``np.inf`` does not match ``-np.inf``.
    min_length : int, default=2
        Minimum lengths of segments with same value to include.
        If min_length is set to 1, the transformer can be used as a value
        finder.
    """

    _tags = {
        "fit_is_empty": True,
        "capability:multivariate": False,
        "output_data_type": "Series",
        "instancewise": False,
        "X_inner_type": "numpy3D",
        "y_inner_type": "None",
    }

    def __init__(self, value=np.nan, min_length=2):
        self.value = value
        self.min_length = min_length
        super().__init__(_output_convert=False)

    def _transform(self, X, y=None):
        """Transform X.

        Parameters
        ----------
        X : np.ndarray of shape (n_cases, 1, n_timepoints)
            Only channel 0 of each case is read.
        y : ignored argument for interface compatibility

        Returns
        -------
        Xt : pd.DataFrame
            One row per case with two columns, ``channel__<value>_starts`` and
            ``channel__<value>_lengths``; each cell holds a pd.Series with the
            start indices or the lengths of the plateaus found in that case.
        """
        value_is_nan = np.isnan(self.value)

        starts_per_case = []
        lengths_per_case = []

        # find plateaus (segments of the same value)
        for x in X[:, 0]:
            # 1 where x equals the target value, else 0. NaN needs np.isnan
            # because NaN != NaN; +/-inf compare equal only to themselves.
            if value_is_nan:
                is_value = np.where(np.isnan(x), 1, 0)
            else:
                is_value = np.where(x == self.value, 1, 0)

            # pad with zeros so runs touching either end still produce a
            # +1 (start) and a -1 (end) transition
            transitions = np.diff(np.hstack([0, is_value, 0]))

            # compute starts, ends and lengths of the segments
            starts = np.where(transitions == 1)[0]
            ends = np.where(transitions == -1)[0]
            lengths = ends - starts

            # drop segments shorter than min_length
            keep = lengths >= self.min_length
            starts_per_case.append(starts[keep])
            lengths_per_case.append(lengths[keep])

        # put into dataframe; note the prefix has a double underscore,
        # e.g. "channel__nan"
        column_prefix = "{}_{}".format(
            "channel_",
            "nan" if value_is_nan else str(self.value),
        )
        Xt = pd.DataFrame()
        Xt[f"{column_prefix}_starts"] = pd.Series(starts_per_case)
        Xt[f"{column_prefix}_lengths"] = pd.Series(lengths_per_case)

        # DataFrame.applymap is deprecated in favour of DataFrame.map (pandas 2.1)
        if hasattr(pd.DataFrame, "map"):
            Xt = Xt.map(lambda cell: pd.Series(cell))
        else:
            Xt = Xt.applymap(lambda cell: pd.Series(cell))
        return Xt


class FittedParamExtractor(BaseTransformer):
    """Fitted parameter extractor.

    Extract parameters of a fitted forecaster as features for a subsequent
    tabular learning task.
    This class first fits a forecaster to the given time series and then
    returns the fitted parameters.
    The fitted parameters can be used as features for a tabular estimator
    (e.g. classification).

    Parameters
    ----------
    forecaster : estimator object
        An aeon estimator to extract features from. A clone is fitted to each
        case, so the passed instance is left unchanged.
    param_names : str or list/tuple of str
        Name(s) of the fitted parameters to extract from the forecaster. Must be
        non-empty.
    n_jobs : int, default=None
        Number of jobs to run in parallel.
        None means 1 unless in a joblib.parallel_backend context.
        -1 means using all processors.
    """

    _tags = {
        "fit_is_empty": True,
        "capability:multivariate": False,
        "input_data_type": "Series",
        # what is the abstract type of X: Series, or Panel
        "output_data_type": "Primitives",
        # what is the abstract type of y: None (not needed), Primitives, Series, Panel
        "instancewise": True,
        "X_inner_type": "numpy3D",
        "y_inner_type": "None",
    }

    def __init__(self, forecaster, param_names, n_jobs=None):
        self.forecaster = forecaster
        self.param_names = param_names
        self.n_jobs = n_jobs
        super().__init__(_output_convert=True)

    def _transform(self, X, y=None):
        """Transform X.

        Parameters
        ----------
        X: np.ndarray shape (n_cases, 1, n_timepoints)
            The training input samples.
        y : ignored argument for interface compatibility
            Additional data, e.g., labels for transformation

        Returns
        -------
        Xt : pd.DataFrame
            Extracted parameters; one row per case, one column per parameter name.

        Raises
        ------
        ValueError
            If ``param_names`` is not a str or a non-empty list/tuple of str.
        """
        param_names = self._check_param_names(self.param_names)
        n_cases = X.shape[0]

        def _fit_extract(forecaster, x, param_names):
            # fit one (cloned) forecaster and stack the requested parameters
            forecaster.fit(x)
            params = forecaster.get_fitted_params()
            return np.hstack([params[name] for name in param_names])

        def _get_instance(X, key):
            # assuming univariate data
            if isinstance(X, pd.DataFrame):
                return X.iloc[key, 0]
            else:
                return pd.Series(X[key, 0])

        # iterate over rows
        extracted_params = Parallel(n_jobs=self.n_jobs)(
            delayed(_fit_extract)(
                self.forecaster.clone(), _get_instance(X, i), param_names
            )
            for i in range(n_cases)
        )

        return pd.DataFrame(extracted_params, columns=param_names)

    @staticmethod
    def _check_param_names(param_names):
        """Validate ``param_names``; return it as a list or tuple of str.

        A single str is wrapped in a list. An empty list/tuple is rejected here
        because ``np.hstack`` of no parameters would fail inside every worker.

        Raises
        ------
        ValueError
            If param_names is not a str or a non-empty list/tuple of str.
        """
        if isinstance(param_names, str):
            return [param_names]
        if not isinstance(param_names, (list, tuple)):
            raise ValueError(
                f"`param_names` must be str, or a list or tuple of strings, "
                f"but found: {type(param_names)}"
            )
        if len(param_names) == 0:
            raise ValueError("`param_names` must contain at least one name.")
        for param in param_names:
            if not isinstance(param, str):
                raise ValueError(
                    f"All elements of `param_names` must be strings, "
                    f"but found: {type(param)}"
                )
        return param_names

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return `"default"` set.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`
        """
        from aeon.forecasting.trend import TrendForecaster

        # accessing a nested parameter
        params = [
            {
                "forecaster": TrendForecaster(),
                "param_names": ["regressor__intercept"],
            }
        ]
        return params
