// SPDX-License-Identifier: AGPL-3.0-only

package streamingpromql

import (
	"context"
	"testing"
	"time"

	"github.com/prometheus/prometheus/model/timestamp"
	"github.com/prometheus/prometheus/promql/parser"
	"github.com/prometheus/prometheus/promql/promqltest"
	"github.com/stretchr/testify/require"

	"github.com/grafana/mimir/pkg/querier/stats"
	"github.com/grafana/mimir/pkg/streamingpromql/operators/functions"
)

// This test ensures that all functions correctly merge series after dropping the metric name.
func TestFunctionDeduplicateAndMerge(t *testing.T) {
	// Each pair of fixture series (float_a/float_b and histogram_a/histogram_b) differs only by
	// metric name, and the two members of a pair have samples at different timestamps. Once a
	// function drops __name__, the pair should therefore collapse into exactly one output series.
	data := `
		load 30s
			float_a{env="prod"}      _   0 1                       _ _   _ _   _ _   _ _   _ _   _ _
			float_b{env="prod"}      _   _ _                       _ _   _ _   _ _   _ _   _ _   8 9
			histogram_a{env="prod"}  _   {{count:0}} {{count:1}}   _ _   _ _   _ _   _ _   _ _   _ _
			histogram_b{env="prod"}  _   _ _                       _ _   _ _   _ _   _ _   _ _   {{count:8}} {{count:9}}
	`

	store := promqltest.LoadedStorage(t, data)
	engineOpts := NewTestEngineOpts()
	queryPlanner, err := NewQueryPlanner(engineOpts, NewMaximumSupportedVersionQueryPlanVersionProvider())
	require.NoError(t, err)
	eng, err := NewEngine(engineOpts, NewStaticQueryLimitsProvider(0), stats.NewQueryMetrics(nil), queryPlanner)
	require.NoError(t, err)

	// Range query window shared by every subtest.
	var (
		ctx   = context.Background()
		start = timestamp.Time(0).Add(time.Minute)
		end   = timestamp.Time(0).Add(7 * time.Minute)
		step  = time.Minute
	)

	// Maps each PromQL function name to a query exercising it against the fixtures above.
	// The sentinel `<skip>` marks functions this test does not apply to.
	expressions := map[string]string{
		//lint:sorted
		"abs":                          `abs({__name__=~"float.*"})`,
		"absent":                       `<skip>`,
		"absent_over_time":             `<skip>`,
		"acos":                         `acos({__name__=~"float.*"})`,
		"acosh":                        `acosh({__name__=~"float.*"})`,
		"asin":                         `asin({__name__=~"float.*"})`,
		"asinh":                        `asinh({__name__=~"float.*"})`,
		"atan":                         `atan({__name__=~"float.*"})`,
		"atanh":                        `atanh({__name__=~"float.*"})`,
		"avg_over_time":                `avg_over_time({__name__=~"float.*"}[1m])`,
		"ceil":                         `ceil({__name__=~"float.*"})`,
		"changes":                      `changes({__name__=~"float.*"}[1m])`,
		"clamp":                        `clamp({__name__=~"float.*"}, -Inf, Inf)`,
		"clamp_max":                    `clamp_max({__name__=~"float.*"}, -Inf)`,
		"clamp_min":                    `clamp_min({__name__=~"float.*"}, Inf)`,
		"cos":                          `cos({__name__=~"float.*"})`,
		"cosh":                         `cosh({__name__=~"float.*"})`,
		"count_over_time":              `count_over_time({__name__=~"float.*"}[1m])`,
		"day_of_month":                 `day_of_month({__name__=~"float.*"})`,
		"day_of_week":                  `day_of_week({__name__=~"float.*"})`,
		"day_of_year":                  `day_of_year({__name__=~"float.*"})`,
		"days_in_month":                `days_in_month({__name__=~"float.*"})`,
		"deg":                          `deg({__name__=~"float.*"})`,
		"delta":                        `delta({__name__=~"float.*"}[1m])`,
		"deriv":                        `deriv({__name__=~"float.*"}[1m])`,
		"double_exponential_smoothing": `double_exponential_smoothing({__name__=~"float.*"}[1m], 0.01, 0.1)`,
		"exp":                          `exp({__name__=~"float.*"})`,
		"first_over_time":              `<skip>`, // first_over_time() doesn't drop the metric name, so this test doesn't apply.
		"floor":                        `floor({__name__=~"float.*"})`,
		"histogram_avg":                `histogram_avg({__name__=~"histogram.*"})`,
		"histogram_count":              `histogram_count({__name__=~"histogram.*"})`,
		"histogram_fraction":           `histogram_fraction(0, 0.1, {__name__=~"histogram.*"})`,
		"histogram_quantile":           `histogram_quantile(0.1, {__name__=~"histogram.*"})`,
		"histogram_stddev":             `histogram_stddev({__name__=~"histogram.*"})`,
		"histogram_stdvar":             `histogram_stdvar({__name__=~"histogram.*"})`,
		"histogram_sum":                `histogram_sum({__name__=~"histogram.*"})`,
		"hour":                         `hour({__name__=~"float.*"})`,
		"idelta":                       `idelta({__name__=~"float.*"}[1m])`,
		"increase":                     `increase({__name__=~"float.*"}[1m])`,
		"irate":                        `irate({__name__=~"float.*"}[1m])`,
		"label_join":                   `label_join({__name__=~"float.*"}, "__name__", "", "env")`,
		"label_replace":                `label_replace({__name__=~"float.*"}, "__name__", "$1", "env", "(.*)")`,
		"last_over_time":               `<skip>`, // last_over_time() doesn't drop the metric name, so this test doesn't apply.
		"ln":                           `ln({__name__=~"float.*"})`,
		"log10":                        `log10({__name__=~"float.*"})`,
		"log2":                         `log2({__name__=~"float.*"})`,
		"mad_over_time":                `mad_over_time({__name__=~"float.*"}[1m])`,
		"max_over_time":                `max_over_time({__name__=~"float.*"}[1m])`,
		"min_over_time":                `min_over_time({__name__=~"float.*"}[1m])`,
		"minute":                       `minute({__name__=~"float.*"})`,
		"month":                        `month({__name__=~"float.*"})`,
		"predict_linear":               `predict_linear({__name__=~"float.*"}[1m], 30)`,
		"present_over_time":            `present_over_time({__name__=~"float.*"}[1m])`,
		"quantile_over_time":           `quantile_over_time(0.5, {__name__=~"float.*"}[1m])`,
		"rad":                          `rad({__name__=~"float.*"})`,
		"rate":                         `rate({__name__=~"float.*"}[1m])`,
		"resets":                       `resets({__name__=~"float.*"}[1m])`,
		"round":                        `round({__name__=~"float.*"})`,
		"sgn":                          `sgn({__name__=~"float.*"})`,
		"sin":                          `sin({__name__=~"float.*"})`,
		"sinh":                         `sinh({__name__=~"float.*"})`,
		"sort":                         `<skip>`, // sort*() functions don't drop the metric name, so this test doesn't apply.
		"sort_by_label":                `<skip>`, // sort*() functions don't drop the metric name, so this test doesn't apply.
		"sort_by_label_desc":           `<skip>`, // sort*() functions don't drop the metric name, so this test doesn't apply.
		"sort_desc":                    `<skip>`, // sort*() functions don't drop the metric name, so this test doesn't apply.
		"sqrt":                         `sqrt({__name__=~"float.*"})`,
		"stddev_over_time":             `stddev_over_time({__name__=~"float.*"}[1m])`,
		"stdvar_over_time":             `stdvar_over_time({__name__=~"float.*"}[1m])`,
		"sum_over_time":                `sum_over_time({__name__=~"float.*"}[1m])`,
		"tan":                          `tan({__name__=~"float.*"})`,
		"tanh":                         `tanh({__name__=~"float.*"})`,
		"timestamp":                    `timestamp({__name__=~"float.*"})`,
		"ts_of_first_over_time":        `ts_of_first_over_time({__name__=~"float.*"}[1m])`,
		"ts_of_last_over_time":         `ts_of_last_over_time({__name__=~"float.*"}[1m])`,
		"ts_of_max_over_time":          `ts_of_max_over_time({__name__=~"float.*"}[1m])`,
		"ts_of_min_over_time":          `ts_of_min_over_time({__name__=~"float.*"}[1m])`,
		"vector":                       `<skip>`, // vector() takes a scalar, so this test doesn't apply.
		"year":                         `year({__name__=~"float.*"})`,
	}

	// Iterate over every registered function so that newly added functions fail this test
	// until an expression (or an explicit `<skip>`) is added for them above.
	for fn, meta := range functions.RegisteredFunctions {
		fnName := fn.PromQLName()
		query, found := expressions[fnName]

		explicitlySkipped := query == "<skip>"
		returnsScalar := meta.ReturnType == parser.ValueTypeScalar
		if explicitlySkipped || returnsScalar {
			continue
		}

		t.Run(fnName, func(t *testing.T) {
			require.Truef(t, found, "no expression defined for '%s' function", fnName)

			rangeQuery, err := eng.NewRangeQuery(ctx, store, nil, query, start, end, step)
			require.NoError(t, err)
			defer rangeQuery.Close()

			res := rangeQuery.Exec(ctx)
			require.NoError(t, res.Err)

			matrix, err := res.Matrix()
			require.NoError(t, err)

			// Both source series must have been merged into a single output series.
			require.Len(t, matrix, 1, "expected a single series")
		})
	}
}
