#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# based on Steffen Röcker's gist
# https://gist.github.com/sroecker/f52646023d45ab4e281ddee05d6ef2a5
# Standard
import os

# Third Party
import torch


def _gib(size: int) -> str:
    """Format a byte count as gibibytes with one decimal place.

    Args:
        size: Number of bytes.

    Returns:
        The value divided by 1024**3, rendered like ``"1.5 GiB"``.
    """
    bytes_per_gib = 1024**3
    return f"{size / bytes_per_gib:.1f} GiB"


# Report the PyTorch version, the HSA/HIP environment variables and the
# CUDA / ROCm HIP versions this torch build exposes.
print(f"pytorch version: {torch.__version__}")
for env_name in ("HSA_OVERRIDE_GFX_VERSION", "HIP_VISIBLE_DEVICES"):
    print(f"{env_name}={os.environ.get(env_name)}")
gpu_available = torch.cuda.is_available()
print(f"GPU available: {gpu_available}")
print(f"NVidia CUDA version: {torch.version.cuda or 'n/a'}")
print(f"AMD ROCm HIP version: {torch.version.hip or 'n/a'}")
print()

# Pick the device used by the tests below: the current GPU when one is
# available, otherwise the CPU.
if not gpu_available:
    device = torch.device("cpu")
else:
    gpu_count = torch.cuda.device_count()
    print(f"device count: {gpu_count}")
    print(f"current device: {torch.cuda.current_device()}")
    # One summary line per visible GPU: name, free/total memory, capability.
    for gpu_index in range(gpu_count):
        gpu = torch.device("cuda", gpu_index)
        gpu_name = torch.cuda.get_device_name(gpu)
        free_bytes, total_bytes = torch.cuda.mem_get_info(gpu)
        # get_device_capability() yields a (major, minor) pair.
        major, minor = torch.cuda.get_device_capability(gpu)
        memory_summary = f"{_gib(free_bytes)} of {_gib(total_bytes)} free"
        print(
            f"  {gpu} is '{gpu_name}' ({memory_summary}, "
            f"capability: {major}.{minor})"
        )
    device = torch.device("cuda", torch.cuda.current_device())

print()
print(f"Testing with {device}")

# Multiply two random square matrices on the selected device, first a tiny
# 3x3 pair and then a 1280x1280 pair, printing each product and its size.
for label, dim in (("small matmul", 3), ("larger matmul", 1280)):
    print()
    print(label)
    lhs = torch.rand(dim, dim, device=device)
    rhs = torch.rand(dim, dim, device=device)
    product = torch.matmul(lhs, rhs)
    print(product)
    print(product.size())

print()
# Try a 3x3 matmul in each reduced-precision dtype. Any failure is reported
# (truncated to 256 characters) instead of aborting the script.
for half_dtype in (torch.float16, torch.bfloat16):
    print(f"Testing dtype '{half_dtype}' on {device}")
    shape = (3, 3)
    try:
        lhs = torch.rand(*shape, device=device, dtype=half_dtype)
        rhs = torch.rand(*shape, device=device, dtype=half_dtype)
        # Printing stays inside the try so errors that only surface when the
        # result is rendered are reported the same way.
        print(torch.matmul(lhs, rhs))
    except Exception as exc:
        reason = str(exc)[:256]
        print(f"dtype {half_dtype} is not supported: {reason}")
    print()
