// Copyright (c) 2024-2025 The Regents of the University of Michigan.
// Part of row, released under the BSD 3-Clause License.

use log::{debug, error, trace, warn};
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::fmt::Write as _;
use std::io::Write;
use std::os::unix::process::ExitStatusExt;
use std::path::{Path, PathBuf};
use std::process::{Child, Command, Stdio};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::{str, thread};

use crate::Error;
use crate::cluster::Cluster;
use crate::launcher::Launcher;
use crate::scheduler::bash::BashScriptBuilder;
use crate::scheduler::{ActiveJobs, Scheduler};
use crate::workflow::Action;

/// The `Slurm` scheduler constructs bash scripts and executes them with `sbatch`.
pub struct Slurm {
    /// The target cluster. Its name selects the action's per-cluster submit options
    /// and its partitions/submit options shape the `#SBATCH` preamble.
    cluster: Cluster,
    /// Launcher configurations handed to the bash script builder.
    launchers: HashMap<String, Launcher>,
}

impl Slurm {
    /// Construct a new Slurm scheduler for `cluster` using the given `launchers`.
    pub fn new(cluster: Cluster, launchers: HashMap<String, Launcher>) -> Self {
        Self { cluster, launchers }
    }

    /// Append a `#SBATCH --mem-per-{processor_type}=<N>M` line to `preamble`.
    ///
    /// * Neither `action_mem` nor `partition_mem` set: nothing is written.
    /// * Exactly one set: that amount is requested.
    /// * Both set: the action's amount is requested. A warning is logged when it is
    ///   below the partition's amount, because omitting it would request more memory
    ///   at no cost. Requesting exactly the partition's amount is accepted silently.
    ///
    /// # Errors
    ///
    /// Returns `Error::TooMuchMemory` when the action requests strictly more memory
    /// than the partition provides.
    fn write_mem_per(
        preamble: &mut String,
        action_mem: Option<usize>,
        partition_mem: Option<usize>,
        processor_type: &str,
        action_name: &str,
    ) -> Result<(), Error> {
        let mem = match (action_mem, partition_mem) {
            (None, None) => return Ok(()),
            (None, Some(mem)) | (Some(mem), None) => mem,
            (Some(mem_action), Some(mem_partition)) => {
                if mem_action > mem_partition {
                    return Err(Error::TooMuchMemory(action_name.into(), mem_action));
                }
                if mem_action < mem_partition {
                    warn!(
                        "Omit `memory_per_{processor_type}_mb` in action '{action_name}' to request more memory at no cost."
                    );
                }
                // An exact match with the partition is a valid request, not an error.
                mem_action
            }
        };

        // Writing to a `String` cannot fail.
        let _ = writeln!(preamble, "#SBATCH --mem-per-{processor_type}={mem}M");

        Ok(())
    }
}

/// Handle to the `squeue` query launched by the Slurm scheduler's `active_jobs`.
///
/// Consume it through `ActiveJobs::get` to wait for the query and collect the
/// job IDs that are still queued.
pub struct ActiveSlurmJobs {
    /// The running `squeue` child process, or `None` when no process was launched.
    squeue: Option<Child>,
    /// The number of job IDs queried; used to size the result set.
    max_jobs: usize,
}

impl Scheduler for Slurm {
    #[allow(clippy::too_many_lines)]
    fn make_script(
        &self,
        action: &Action,
        directories: &[PathBuf],
        workspace_path: &Path,
        directory_values: &HashMap<PathBuf, Value>,
    ) -> Result<String, Error> {
        let mut preamble = String::with_capacity(512);
        let mut user_partition = &None;

        write!(preamble, "#SBATCH --job-name={}", action.name()).expect("valid format");
        let _ = match directories.first() {
            Some(directory) => match directories.len() {
                0..=1 => writeln!(preamble, "-{}", directory.display()),
                _ => writeln!(
                    preamble,
                    "-{}+{}",
                    directory.display(),
                    directories.len() - 1
                ),
            },
            None => writeln!(preamble),
        };

        // The output file directory and filename
        let output_path = action.submit_options.get(&self.cluster.name).map_or_else(
            || format!("{}-%j.out", action.name()),
            |submit_options| {
                let mut path = submit_options
                    .output_file_path
                    .as_deref()
                    .unwrap_or("")
                    .to_string();
                if !path.is_empty() && !path.ends_with('/') {
                    path.push('/');
                }

                match submit_options.output_file_name {
                    None => format!("{path}{}-%j.out", action.name()),
                    Some(ref name) => {
                        let replaced_name = name.replace("{action_name}", action.name());
                        format!("{path}{replaced_name}")
                    }
                }
            },
        );
        let _ = writeln!(preamble, "#SBATCH --output={output_path}");

        if let Some(submit_options) = action.submit_options.get(&self.cluster.name) {
            user_partition = &submit_options.partition;
        }

        // The partition
        let partition = self.cluster.find_partition(
            user_partition.as_deref(),
            &action.resources,
            directories.len(),
        )?;
        let _ = writeln!(preamble, "#SBATCH --partition={}", partition.name);

        // Resources
        let _ = writeln!(
            preamble,
            "#SBATCH --ntasks={}",
            action.resources.total_processes(directories.len())
        );

        if let Some(threads_per_process) = action.resources.threads_per_process {
            let _ = writeln!(preamble, "#SBATCH --cpus-per-task={threads_per_process}");
        }
        if let Some(gpus_per_process) = action.resources.gpus_per_process {
            let _ = writeln!(preamble, "#SBATCH --gpus-per-task={gpus_per_process}");

            if let Some(ref gpus_per_node) = partition.gpus_per_node {
                let n_nodes = action
                    .resources
                    .total_gpus(directories.len())
                    .div_ceil(*gpus_per_node);
                let _ = writeln!(preamble, "#SBATCH --nodes={n_nodes}");
            }

            Slurm::write_mem_per(
                &mut preamble,
                action.resources.memory_per_gpu_mb,
                partition.memory_per_gpu_mb,
                "gpu",
                action.name(),
            )?;
        } else {
            if let Some(ref cpus_per_node) = partition.cpus_per_node {
                let n_nodes = action
                    .resources
                    .total_cpus(directories.len())
                    .div_ceil(*cpus_per_node);
                let _ = writeln!(preamble, "#SBATCH --nodes={n_nodes}");
            }

            Slurm::write_mem_per(
                &mut preamble,
                action.resources.memory_per_cpu_mb,
                partition.memory_per_cpu_mb,
                "cpu",
                action.name(),
            )?;
        }

        // Slurm doesn't store times in seconds, so round up to the nearest minute.
        let total = action
            .resources
            .total_walltime(directories.len())
            .signed_total_seconds();
        let minutes = (total + 59) / 60;
        let _ = writeln!(preamble, "#SBATCH --time={minutes}");

        // Add global cluster submit options first so that users can override them.
        for option in &self.cluster.submit_options {
            let _ = writeln!(preamble, "#SBATCH {option}");
        }

        // Use provided submission options
        if let Some(submit_options) = action.submit_options.get(&self.cluster.name) {
            if let Some(ref account) = submit_options.account {
                if let Some(ref suffix) = partition.account_suffix {
                    let _ = writeln!(preamble, "#SBATCH --account={account}{suffix}");
                } else {
                    let _ = writeln!(preamble, "#SBATCH --account={account}");
                }
            }
            for option in &submit_options.custom {
                let _ = writeln!(preamble, "#SBATCH {option}");
            }
        }

        BashScriptBuilder::new(
            &self.cluster.name,
            action,
            directories,
            workspace_path,
            directory_values,
            &self.launchers,
        )
        .with_preamble(&preamble)
        .build()
    }

    fn submit(
        &self,
        workflow_root: &Path,
        action: &Action,
        directories: &[PathBuf],
        workspace_path: &Path,
        directory_values: &HashMap<PathBuf, Value>,
        should_terminate: Arc<AtomicBool>,
    ) -> Result<Option<u32>, Error> {
        debug!("Submtitting '{}' with sbatch.", action.name());

        // output() below is blocking with no convenient way to interrupt it.
        // If the user pressed ctrl-C, let the current call to submit() finish
        // and update the cache. Assuming that there will be a next call to
        // submit(), that next call will return with an Interrupted error before
        // submitting the next job.
        if should_terminate.load(Ordering::Relaxed) {
            error!("Interrupted! Cancelling further job submissions.");
            return Err(Error::Interrupted);
        }

        let script = self.make_script(action, directories, workspace_path, directory_values)?;

        let mut child = Command::new("sbatch")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .arg("--parsable")
            .current_dir(workflow_root)
            .spawn()
            .map_err(|e| Error::SpawnProcess("sbatch".into(), e))?;

        let mut stdin = child.stdin.take().expect("Piped stdin");
        let input_thread = thread::spawn(move || {
            let _ = write!(stdin, "{script}");
        });

        trace!("Waiting for sbatch to complete.");
        let output = child
            .wait_with_output()
            .map_err(|e| Error::SpawnProcess("sbatch".into(), e))?;

        input_thread.join().expect("The thread should not panic");

        if output.status.success() {
            let job_id_string = str::from_utf8(&output.stdout).expect("Valid UTF-8 output");
            let job_id = job_id_string
                .trim_end_matches(char::is_whitespace)
                .parse::<u32>()
                .map_err(|_| Error::UnexpectedOutput("sbatch".into(), job_id_string.into()))?;
            Ok(Some(job_id))
        } else {
            let message = match output.status.code() {
                None => match output.status.signal() {
                    None => "sbatch was terminated by a unknown signal".to_string(),
                    Some(signal) => format!("sbatch was terminated by signal {signal}"),
                },
                Some(code) => format!("sbatch exited with code {code}"),
            };
            Err(Error::SubmitAction(action.name().into(), message))
        }
    }

    /** Use `squeue` to determine the jobs that are still present in the queue.

    Launch `squeue --jobs job0,job1,job2 -o "%A" --noheader` to determine which of
    these jobs are still in the queue.
    */
    fn active_jobs(&self, jobs: &[u32]) -> Result<Box<dyn ActiveJobs>, Error> {
        if jobs.is_empty() {
            return Ok(Box::new(ActiveSlurmJobs {
                squeue: None,
                max_jobs: 0,
            }));
        }

        debug!("Checking job status with squeue.");

        let mut jobs_string = String::with_capacity(9 * jobs.len());
        // Prefix the --jobs argument with "1,". Otherwise, squeue reports an
        // error when a single job is not in the queue.
        if jobs.len() == 1 {
            jobs_string.push_str("1,");
        }
        for job in jobs {
            let _ = write!(jobs_string, "{job},");
        }

        let squeue = Command::new("squeue")
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .arg("--jobs")
            .arg(&jobs_string)
            .args(["-o", "%A"])
            .arg("--noheader")
            .spawn()
            .map_err(|e| Error::SpawnProcess("squeue".into(), e))?;

        Ok(Box::new(ActiveSlurmJobs {
            squeue: Some(squeue),
            max_jobs: jobs.len(),
        }))
    }
}

impl ActiveJobs for ActiveSlurmJobs {
    fn get(self: Box<Self>) -> Result<HashSet<u32>, Error> {
        let mut result = HashSet::with_capacity(self.max_jobs);

        if let Some(squeue) = self.squeue {
            trace!("Waiting for squeue to complete.");
            let output = squeue
                .wait_with_output()
                .map_err(|e| Error::SpawnProcess("sbatch".into(), e))?;

            if !output.status.success() {
                let message = match output.status.code() {
                    None => match output.status.signal() {
                        None => "squeue was terminated by a unknown signal".to_string(),
                        Some(signal) => format!("squeue was terminated by signal {signal}"),
                    },
                    Some(code) => format!("squeue exited with code {code}"),
                };
                return Err(Error::ExecuteSqueue(
                    message,
                    str::from_utf8(&output.stderr).expect("Valid UTF-8").into(),
                ));
            }

            let jobs = str::from_utf8(&output.stdout).expect("Valid UTF-8");
            for job in jobs.lines() {
                result.insert(
                    job.parse()
                        .map_err(|_| Error::UnexpectedOutput("squeue".into(), job.into()))?,
                );
            }
        }

        Ok(result)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::parallel;

    use crate::builtin::BuiltIn;
    use crate::cluster::{Cluster, IdentificationMethod, Partition, SchedulerType};
    use crate::launcher;
    use crate::workflow::{Processes, SubmitOptions};

    /// Build a `Slurm` scheduler with the built-in launchers for a cluster named
    /// "cluster" that has exactly one partition, `partition`.
    fn slurm_with_partition(partition: Partition) -> Slurm {
        let launchers = launcher::Configuration::built_in();
        let cluster = Cluster {
            name: "cluster".into(),
            identify: IdentificationMethod::Always(false),
            scheduler: SchedulerType::Slurm,
            partition: vec![partition],
            submit_options: Vec::new(),
        };

        Slurm::new(cluster, launchers.by_cluster("cluster"))
    }

    /// Common fixture: an action named "action" using the "mpi" launcher, three
    /// directories, and a scheduler whose only partition is `Partition::default()`.
    fn setup() -> (Action, Vec<PathBuf>, Slurm) {
        let action = Action {
            name: Some("action".to_string()),
            command: Some("command {directory}".to_string()),
            launchers: Some(vec!["mpi".into()]),
            ..Action::default()
        };

        let directories = vec![PathBuf::from("a"), PathBuf::from("b"), PathBuf::from("c")];
        (action, directories, slurm_with_partition(Partition::default()))
    }

    /// Generate the script for `action` on `directories` (panicking on error),
    /// print it to aid debugging failed assertions, and return it.
    fn render(slurm: &Slurm, action: &Action, directories: &[PathBuf]) -> String {
        let script = slurm
            .make_script(action, directories, &PathBuf::default(), &HashMap::new())
            .expect("valid script");
        println!("{script}");
        script
    }

    #[test]
    #[parallel]
    fn default() {
        let (action, directories, slurm) = setup();
        let script = render(&slurm, &action, &directories);

        assert!(script.contains("#SBATCH --job-name=action"));
        assert!(script.contains("#SBATCH --ntasks=1"));
        assert!(!script.contains("#SBATCH --account"));
        assert!(script.contains("#SBATCH --partition=partition"));
        assert!(!script.contains("#SBATCH --cpus-per-task"));
        assert!(!script.contains("#SBATCH --gpus-per-task"));
        assert!(script.contains("#SBATCH --time=180"));
    }

    #[test]
    #[parallel]
    fn cluster_submit_options() {
        let (action, directories, mut slurm) = setup();
        slurm.cluster.submit_options = vec!["--option=value".to_string()];

        let script = render(&slurm, &action, &directories);

        assert!(script.contains("#SBATCH --job-name=action"));
        assert!(script.contains("#SBATCH --ntasks=1"));
        assert!(!script.contains("#SBATCH --account"));
        assert!(script.contains("#SBATCH --partition=partition"));
        assert!(!script.contains("#SBATCH --cpus-per-task"));
        assert!(!script.contains("#SBATCH --gpus-per-task"));
        assert!(script.contains("#SBATCH --time=180"));
        assert!(script.contains("#SBATCH --option=value"));
        assert!(script.contains("#SBATCH --output=action-%j.out"));
    }

    #[test]
    #[parallel]
    fn ntasks() {
        let (mut action, directories, slurm) = setup();

        action.resources.processes = Some(Processes::PerDirectory(3));

        let script = render(&slurm, &action, &directories);

        assert!(script.contains("#SBATCH --ntasks=9"));
    }

    #[test]
    #[parallel]
    fn account() {
        let (mut action, directories, slurm) = setup();

        action.submit_options.insert(
            "cluster".into(),
            SubmitOptions {
                account: Some("c".into()),
                ..SubmitOptions::default()
            },
        );

        let script = render(&slurm, &action, &directories);

        assert!(script.contains("#SBATCH --account=c"));
    }

    #[test]
    #[parallel]
    fn output() {
        let (mut action, directories, slurm) = setup();

        // With a directory, but no filename specified.
        action.submit_options.insert(
            "cluster".into(),
            SubmitOptions {
                output_file_path: Some("dir".into()),
                ..SubmitOptions::default()
            },
        );

        let script = render(&slurm, &action, &directories);
        assert!(script.contains("#SBATCH --output=dir/action-%j.out"));

        // With both a directory and a filename specified.
        action
            .submit_options
            .entry("cluster".into())
            .and_modify(|submit_options| {
                submit_options.output_file_name = Some("fname_{action_name}.out".into());
            });

        let script = render(&slurm, &action, &directories);
        assert!(script.contains("#SBATCH --output=dir/fname_action.out"));

        // With a filename, but no directory specified.
        action
            .submit_options
            .entry("cluster".into())
            .and_modify(|submit_options| submit_options.output_file_path = None);

        let script = render(&slurm, &action, &directories);
        assert!(script.contains("#SBATCH --output=fname_action.out"));

        // With both a filename and a directory specified, the directory ending in '/'.
        action
            .submit_options
            .entry("cluster".into())
            .and_modify(|submit_options| submit_options.output_file_path = Some("dir/".into()));

        let script = render(&slurm, &action, &directories);
        assert!(script.contains("#SBATCH --output=dir/fname_action.out"));
    }

    #[test]
    #[parallel]
    fn custom() {
        let (mut action, directories, slurm) = setup();

        action.submit_options.insert(
            "cluster".into(),
            SubmitOptions {
                custom: vec!["custom0".into(), "custom1".into()],
                ..SubmitOptions::default()
            },
        );

        let script = render(&slurm, &action, &directories);

        assert!(script.contains("#SBATCH custom0"));
        assert!(script.contains("#SBATCH custom1"));
    }

    #[test]
    #[parallel]
    fn cpus_per_task() {
        let (mut action, directories, slurm) = setup();

        action.resources.threads_per_process = Some(5);

        let script = render(&slurm, &action, &directories);

        assert!(script.contains("#SBATCH --cpus-per-task=5"));
    }

    #[test]
    #[parallel]
    fn gpus_per_task() {
        let (mut action, directories, slurm) = setup();

        action.resources.gpus_per_process = Some(5);

        let script = render(&slurm, &action, &directories);

        assert!(script.contains("#SBATCH --gpus-per-task=5"));
    }

    #[test]
    #[parallel]
    fn mem_per_cpu() {
        let (mut action, directories, _) = setup();
        let slurm = slurm_with_partition(Partition {
            memory_per_cpu_mb: Some(5),
            ..Partition::default()
        });

        let script = render(&slurm, &action, &directories);
        assert!(script.contains("#SBATCH --mem-per-cpu=5M"));

        action.resources.memory_per_cpu_mb = Some(2);

        let script = render(&slurm, &action, &directories);
        assert!(script.contains("#SBATCH --mem-per-cpu=2M"));

        action.resources.memory_per_cpu_mb = Some(10);

        assert!(matches!(
            slurm.make_script(&action, &directories, &PathBuf::default(), &HashMap::new()),
            Err(Error::TooMuchMemory(_, _))
        ));
    }

    #[test]
    #[parallel]
    fn mem_per_gpu() {
        let (mut action, directories, _) = setup();
        let slurm = slurm_with_partition(Partition {
            memory_per_gpu_mb: Some(12),
            ..Partition::default()
        });

        action.resources.gpus_per_process = Some(1);

        let script = render(&slurm, &action, &directories);
        assert!(script.contains("#SBATCH --mem-per-gpu=12M"));

        action.resources.memory_per_gpu_mb = Some(4);

        let script = render(&slurm, &action, &directories);
        assert!(script.contains("#SBATCH --mem-per-gpu=4M"));

        action.resources.memory_per_gpu_mb = Some(20);

        assert!(matches!(
            slurm.make_script(&action, &directories, &PathBuf::default(), &HashMap::new()),
            Err(Error::TooMuchMemory(_, _))
        ));
    }

    #[test]
    #[parallel]
    fn cpus_per_node() {
        let (mut action, directories, _) = setup();
        let slurm = slurm_with_partition(Partition {
            cpus_per_node: Some(10),
            ..Partition::default()
        });

        action.resources.processes = Some(Processes::PerSubmission(81));

        let script = render(&slurm, &action, &directories);

        assert!(script.contains("#SBATCH --nodes=9"));
    }

    #[test]
    #[parallel]
    fn gpus_per_node() {
        let (mut action, directories, _) = setup();
        let slurm = slurm_with_partition(Partition {
            gpus_per_node: Some(5),
            ..Partition::default()
        });

        action.resources.processes = Some(Processes::PerSubmission(81));
        action.resources.gpus_per_process = Some(1);

        let script = render(&slurm, &action, &directories);

        assert!(script.contains("#SBATCH --nodes=17"));
    }
}
