#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0

# Standard
import argparse
import os
import sys
import time

# Third Party
import llama_cpp

# Default prompt sent to the model when --prompt is not given.
PROMPT = "Q: What is the answer to the ultimate question of life, the universe, and everything? A:"

# The text is passed as ``description``: the first positional parameter of
# ArgumentParser is ``prog``, which would replace the program name in the
# usage line instead of describing the tool.
parser = argparse.ArgumentParser(description="test llama GPU/CPU")
parser.add_argument(
    "--model",
    default=None,
    help="path to the model file (default: model_path from the instructlab config)",
)
parser.add_argument(
    "--device",
    choices=["cpu", "gpu"],
    default="gpu",
    help="'gpu' offloads all layers to the GPU, 'cpu' offloads none (default: gpu)",
)
parser.add_argument(
    "--quiet",
    action="store_false",
    dest="verbose",
    help="load the model with verbose=False",
)
parser.add_argument(
    "--repeat",
    type=int,
    default=5,
    help="number of times the prompt is run (default: 5)",
)
parser.add_argument(
    "--max-tokens",
    type=int,
    default=0,
    help="max_tokens value passed to the model for each completion (default: 0)",
)
parser.add_argument(
    "--prompt",
    default=PROMPT,
    help="prompt text sent to the model",
)


def main(argv=None):
    """Load a model with llama_cpp and time repeated prompt completions.

    Prints the llama_cpp version, whether GPU offload is supported, the
    model load time, the text of every completion choice and the total
    time spent on all repeats.

    Args:
        argv: Command-line arguments to parse. ``None`` lets argparse fall
            back to ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 2 when the instructlab configuration cannot
            be read or the model file does not exist. The error text is
            written to stderr so it does not mix with the results on stdout.
    """
    args = parser.parse_args(argv)
    print(f"llama_cpp version: {llama_cpp.__version__}")
    print(f"llama supports gpu offload: {llama_cpp.llama_supports_gpu_offload()}")
    # -1: offload all layers to GPU
    n_gpu_layers = -1 if args.device == "gpu" else 0
    print(f"n_gpu_layers: {n_gpu_layers} ({args.device})")
    # prints detected devices to stderr
    llama_cpp.llama_backend_init()

    if args.model is None:
        # No explicit model: fall back to the path from the instructlab
        # config. Imported lazily so --model works without reading config.
        # First Party
        from instructlab.configuration import read_config

        try:
            cfg = read_config()
        except Exception as e:
            print(f"ERROR: Unable to test model: {e}", file=sys.stderr)
            sys.exit(2)

        args.model = cfg.serve.model_path

    if not os.path.isfile(args.model):
        print(f"ERROR: Model file '{args.model}' is missing.", file=sys.stderr)
        print("Run 'lab init' and 'lab download'.", file=sys.stderr)
        sys.exit(2)

    # Time the model load separately from inference.
    start = time.monotonic()
    llm = llama_cpp.Llama(
        model_path=args.model,
        n_gpu_layers=n_gpu_layers,
        verbose=args.verbose,
    )
    t = time.monotonic() - start
    print(f"Loaded model in {t:.2f}sec")

    print()
    print(
        f"Prompt: '{args.prompt}' (repeats: {args.repeat}, max_tokens: {args.max_tokens})"
    )
    # The reported time covers all repeats, including printing the results.
    start = time.monotonic()
    for _ in range(args.repeat):
        result = llm(args.prompt, max_tokens=args.max_tokens)
        for choice in result["choices"]:
            print(choice["text"].strip())
    t = time.monotonic() - start
    print(f"Got {args.repeat} result(s) in {t:.2f}sec")


# Run the test only when executed as a script; main() is called without
# arguments, so its argv parameter takes the default of None.
if __name__ == "__main__":
    main()
