# SPDX-FileCopyrightText: 2024 Google LLC
# SPDX-License-Identifier: Apache-2.0

"""Plot the stats contained in the .csv file generated by the test_kraepelin_algorithm
unit test (when STATS_FILE_NAME is defined).
"""

import argparse
import csv
import datetime
import json
import logging
import os
import struct
import sys
import matplotlib.pyplot as pyplot


##################################################################################################
def plot_stats(args, rows, stat_name):
    """Scatter-plot one stat column against 'vmc' for stepping and non-stepping epochs.

    Rows whose 'epoch_type' is 2 (stepping) are drawn as green circles and rows
    whose 'epoch_type' is 0 (non-stepping) as red circles; rows with any other
    epoch type are ignored. The figure is then displayed with pyplot.show().

    Args:
        args: Parsed command-line arguments (not used here).
        rows: Iterable of dicts mapping CSV column names to string values; each
            row must provide 'epoch_type', 'vmc' and ``stat_name``, all parseable
            with int().
        stat_name: Name of the column plotted on the y axis.
    """
    # epoch_type value -> (x values: vmc, y values: the requested stat)
    series_by_epoch = {0: ([], []), 2: ([], [])}
    for record in rows:
        series = series_by_epoch.get(int(record['epoch_type']))
        if series is None:
            continue
        xs, ys = series
        xs.append(int(record['vmc']))
        ys.append(int(record[stat_name]))

    stepping_x, stepping_y = series_by_epoch[2]
    idle_x, idle_y = series_by_epoch[0]
    pyplot.plot(stepping_x, stepping_y, 'go',
                idle_x, idle_y, 'ro')
    pyplot.show()


##################################################################################################
if __name__ == '__main__':
    # Collect our command line arguments
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('infile', help="The input csv file")
    parser.add_argument('--plot', choices=['score_0', 'score_lf', 'total'],
                        default='score_0',
                        help="Which metric to plot against vmc")
    parser.add_argument('--debug', action='store_true', help="Turn on debug logging")
    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    # Read in the csv file. The first non-blank row holds the column names
    # (whitespace-stripped); every later non-blank row becomes a
    # {column name: string value} dict.
    # Under Python 3 the csv module needs a text-mode file opened with
    # newline=''; binary mode ('rb') makes csv.reader raise _csv.Error.
    col_names = None
    rows = []
    with open(args.infile, newline='') as csvfile:
        reader = csv.reader(csvfile)
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines (e.g. a trailing
                # newline); they would otherwise become empty dicts and make
                # plot_stats fail with a KeyError.
                continue
            if col_names is None:
                col_names = [x.strip() for x in row]
            else:
                rows.append(dict(zip(col_names, row)))
    logging.debug("Read %d data rows from %s (columns: %s)",
                  len(rows), args.infile, col_names)

    # plot_stats reads these columns from every row; report a missing one
    # clearly instead of failing later with a bare KeyError.
    missing = {'epoch_type', 'vmc', args.plot}.difference(col_names or ())
    if rows and missing:
        parser.error("%s is missing column(s): %s"
                     % (args.infile, ", ".join(sorted(missing))))

    # Plot now
    plot_stats(args, rows, args.plot)
