package plugins

import (
	"fmt"
	"strconv"
	"sync"
	"time"

	"github.com/rs/zerolog/log"
	"github.com/slack-go/slack"
	"github.com/spf13/cobra"
)

// Names of the command-line flags understood by the slack subcommand.
const (
	slackTokenFlag   = "token"
	slackTeamFlag    = "team"
	slackChannelFlag = "channel"

	slackBackwardDurationFlag = "duration"
	slackMessagesCountFlag    = "messages-count"
)

// slackDefaultDateFrom is the default --duration look-back window: 14 days.
const slackDefaultDateFrom = 14 * 24 * time.Hour

// SlackPlugin scans the messages of a Slack team's channels for sensitive
// information. Its command-line options are read from the package-level
// *Arg variables that DefineCommand binds to flags.
type SlackPlugin struct {
	Plugin
	// Channels holds the items/errors channels and the WaitGroup that tracks
	// the per-channel scanning goroutines (initialised in DefineCommand).
	Channels
	// NOTE(review): Token is not read by the scan in this file (which uses
	// tokenArg); confirm whether callers still rely on it.
	Token string
}

// GetName returns the plugin identifier. It is used as the subcommand name
// and as the prefix of every emitted item ID.
func (p *SlackPlugin) GetName() string {
	const name = "slack"
	return name
}

// Flag values for the slack subcommand, bound to flags in DefineCommand.
var (
	// tokenArg and teamArg hold the required --token and --team values.
	tokenArg, teamArg string
	// channelsArg lists the channel names or IDs passed with --channel.
	channelsArg []string
	// backwardDurationArg is the --duration look-back window (0 disables it).
	backwardDurationArg time.Duration
	// messagesCountArg caps the messages scanned per channel (0 = no cap).
	messagesCountArg int
)

// DefineCommand builds the cobra "slack" subcommand and binds its flags to the
// package-level *Arg variables. When the command runs it scans the selected
// channels, sending discovered items to items and failures to errors. It then
// waits for every per-channel goroutine and closes items.
//
// The returned error is non-nil only if a flag cannot be marked as required.
func (p *SlackPlugin) DefineCommand(items chan ISourceItem, errors chan error) (*cobra.Command, error) {
	p.Channels = Channels{
		Items:  items,
		Errors: errors,
		wg:     &sync.WaitGroup{},
	}

	command := &cobra.Command{
		Use:   fmt.Sprintf("%s --%s TOKEN --%s TEAM", p.GetName(), slackTokenFlag, slackTeamFlag),
		Short: "Scan Slack team",
		Long:  "Scan Slack team for sensitive information.",
		Run: func(cmd *cobra.Command, args []string) {
			p.getItems()
			// getItems only starts the per-channel goroutines. Wait for all of
			// them before closing items so nothing sends on a closed channel.
			p.wg.Wait()
			close(items)
		},
	}

	command.Flags().StringVar(&tokenArg, slackTokenFlag, "", "Slack token [required]")
	if err := command.MarkFlagRequired(slackTokenFlag); err != nil {
		return nil, fmt.Errorf("error while marking flag %s as required: %w", slackTokenFlag, err)
	}
	command.Flags().StringVar(&teamArg, slackTeamFlag, "", "Slack team name or ID [required]")
	if err := command.MarkFlagRequired(slackTeamFlag); err != nil {
		return nil, fmt.Errorf("error while marking flag %s as required: %w", slackTeamFlag, err)
	}
	command.Flags().StringSliceVar(&channelsArg, slackChannelFlag, []string{}, "Slack channels to scan")
	// DurationVar parses with time.ParseDuration, which only knows the units
	// ns, us, ms, s, m and h. Day/month/year suffixes such as "7d" are rejected,
	// so the examples must stick to those units.
	command.Flags().DurationVar(&backwardDurationArg, slackBackwardDurationFlag, slackDefaultDateFrom,
		"Slack backward duration for messages (ex: 24h, 168h for 7 days, 1h30m; 0 = no time limit)")
	command.Flags().IntVar(&messagesCountArg, slackMessagesCountFlag, 0, "Slack messages count to scan (0 = all messages)")

	return command, nil
}

// getItems resolves the configured team and channels and then starts one
// goroutine per selected channel, adding each to p.wg before it starts.
// A failure to resolve the team or the channels is sent to p.Errors and
// nothing is scanned. A team without matching channels only logs a warning.
func (p *SlackPlugin) getItems() {
	api := slack.New(tokenArg)

	team, err := getTeam(api, teamArg)
	if err != nil {
		p.Errors <- fmt.Errorf("error while getting team: %w", err)
		return
	}

	selected, err := getChannels(api, team.ID, channelsArg)
	if err != nil {
		p.Errors <- fmt.Errorf("error while getting channels for team %s: %w", team.Name, err)
		return
	}

	count := len(*selected)
	if count == 0 {
		log.Warn().Msgf("No channels found for team %s", team.Name)
		return
	}
	log.Info().Msgf("Found %d channels for team %s", count, team.Name)

	p.wg.Add(count)
	for i := range *selected {
		go p.getItemsFromChannel(api, (*selected)[i])
	}
}

// getItemsFromChannel pages through the conversation history of channel and
// sends every message with non-empty text to p.Items. It stops at the first
// message rejected by isMessageOutOfRange: one older than backwardDurationArg,
// or one past messagesCountArg messages. Errors from the history API or from
// timestamp parsing are sent to p.Errors and end the scan of this channel.
// A missing permalink only degrades the item's Source. p.wg.Done is always
// called on return.
func (p *SlackPlugin) getItemsFromChannel(
	slackApi *slack.Client,
	channel slack.Channel, //nolint:gocritic // hugeParam: channel is heavy but needed
) {
	defer p.wg.Done()
	log.Info().Msgf("Getting items from channel %s", channel.Name)

	cursor := ""
	counter := 0
	for {
		history, err := slackApi.GetConversationHistory(&slack.GetConversationHistoryParameters{
			Cursor:    cursor,
			ChannelID: channel.ID,
		})
		if err != nil {
			p.Errors <- fmt.Errorf("error while getting history for channel %s: %w", channel.Name, err)
			return
		}
		for i := range history.Messages {
			message := &history.Messages[i]
			outOfRange, err := isMessageOutOfRange(message, backwardDurationArg, counter, messagesCountArg)
			if err != nil {
				p.Errors <- fmt.Errorf("error while checking message: %w", err)
				return
			}
			if outOfRange {
				// History is returned newest-first (the per-page cut-off already
				// relied on this) and counter never decreases. Every remaining
				// message, on this page or any later one, is therefore out of
				// range too. Returning here, instead of only leaving this page,
				// avoids paging through the channel's whole history for nothing.
				return
			}
			if message.Text != "" {
				url, err := slackApi.GetPermalink(&slack.PermalinkParameters{Channel: channel.ID, Ts: message.Timestamp})
				if err != nil {
					log.Warn().Msgf("Error while getting permalink for message %s: %s", message.Timestamp, err)
					url = fmt.Sprintf("Channel: %s; Message: %s", channel.Name, message.Timestamp)
				}
				p.Items <- item{
					Content: &message.Text,
					ID:      fmt.Sprintf("%s-%s-%s", p.GetName(), channel.ID, message.Timestamp),
					Source:  url,
				}
			}
			counter++
		}
		if history.ResponseMetaData.NextCursor == "" {
			return
		}
		cursor = history.ResponseMetaData.NextCursor
	}
}

// timeNow is captured once, when the package is initialised, so every message
// is compared against the same reference instant.
var (
	timeNow = time.Now()
)

// isMessageOutOfRange reports whether message falls outside the scan window.
// That is the case when:
//   - backwardDuration is non-zero and the message is older than
//     timeNow - backwardDuration, or
//   - limitMessagesCount is non-zero and currentMessagesCount has already
//     reached it.
//
// The timestamp is only parsed when backwardDuration is non-zero. If it is not
// a valid float, the function returns (true, err). Presumably Slack timestamps
// look like "1700000000.123456"; the fractional part is discarded.
func isMessageOutOfRange(
	message *slack.Message,
	backwardDuration time.Duration,
	currentMessagesCount,
	limitMessagesCount int,
) (bool, error) {
	if checkAge := backwardDuration != 0; checkAge {
		seconds, parseErr := strconv.ParseFloat(message.Timestamp, 64)
		if parseErr != nil {
			return true, fmt.Errorf("error while parsing timestamp: %w", parseErr)
		}
		oldestAllowed := timeNow.Add(-backwardDuration)
		if sentAt := time.Unix(int64(seconds), 0); sentAt.Before(oldestAllowed) {
			return true, nil
		}
	}

	limitReached := limitMessagesCount != 0 && currentMessagesCount >= limitMessagesCount
	return limitReached, nil
}

// ISlackClient is the subset of the Slack Web API that getTeam and getChannels
// need. *slack.Client satisfies it (getItems passes one). Presumably it is kept
// narrow so tests can substitute a fake; confirm against the test suite.
// Both methods return the cursor of the next page, or "" on the last page.
type ISlackClient interface {
	ListTeams(params slack.ListTeamsParameters) (teams []slack.Team, nextCursor string, err error)
	GetConversations(params *slack.GetConversationsParameters) (channels []slack.Channel, nextCursor string, err error)
}

// getTeam pages through the teams visible to the client and returns the first
// one whose ID or name equals teamName. It fails if the listing errors or no
// team matches.
func getTeam(slackApi ISlackClient, teamName string) (*slack.Team, error) {
	params := slack.ListTeamsParameters{}
	for {
		teams, next, err := slackApi.ListTeams(params)
		if err != nil {
			return nil, fmt.Errorf("error while getting teams: %w", err)
		}
		for i := range teams {
			if teams[i].ID == teamName || teams[i].Name == teamName {
				return &teams[i], nil
			}
		}
		if next == "" {
			return nil, fmt.Errorf("team '%s' not found", teamName)
		}
		params.Cursor = next
	}
}

// getChannels lists the conversations of team teamId, following pagination.
//
// With no wantedChannels, every returned channel is selected. Otherwise a
// channel is selected when its name or ID equals one of wantedChannels. Each
// channel is selected at most once, even if it is requested by both name and
// ID or listed twice. Paging stops as soon as every wanted entry has matched.
// Wanted entries that never match are simply absent from the result.
func getChannels(slackApi ISlackClient, teamId string, wantedChannels []string) (*[]slack.Channel, error) {
	// pending holds the wanted names/IDs that have not matched a channel yet.
	// Being a set, it also collapses entries listed more than once.
	pending := make(map[string]struct{}, len(wantedChannels))
	for _, wanted := range wantedChannels {
		pending[wanted] = struct{}{}
	}

	cursorHolder := ""
	selectedChannels := []slack.Channel{}
	for {
		channels, cursor, err := slackApi.GetConversations(&slack.GetConversationsParameters{
			Cursor: cursorHolder,
			TeamID: teamId,
		})
		if err != nil {
			return nil, fmt.Errorf("error while getting channels: %w", err)
		}
		if len(wantedChannels) == 0 {
			selectedChannels = append(selectedChannels, channels...)
		} else {
			for i := range channels {
				_, byName := pending[channels[i].Name]
				_, byID := pending[channels[i].ID]
				if !byName && !byID {
					continue
				}
				// Consume both keys so a channel requested by name and by ID is
				// selected, and later scanned, only once.
				delete(pending, channels[i].Name)
				delete(pending, channels[i].ID)
				selectedChannels = append(selectedChannels, channels[i])
			}
			if len(pending) == 0 {
				return &selectedChannels, nil
			}
		}
		if cursor == "" {
			return &selectedChannels, nil
		}
		cursorHolder = cursor
	}
}
