# SPDX-License-Identifier: Apache-2.0
# Christian Heimes <cheimes@redhat.com>
# Based on Zack Zlotnik's container file

# Stage "runtime": ROCm libraries, CLI tools, and the Python virtual env.
# Every later stage, including the final image, is built on top of it.
FROM registry.fedoraproject.org/fedora-toolbox:40 AS runtime
# Build args, mirrored into ENV below so they stay set in derived stages and
# in the running container. Defaults: gfx1100, GFX level 11.0.0, first GPU
# only.
ARG AMDGPU_ARCH=gfx1100
ARG HSA_OVERRIDE_GFX_VERSION=11.0.0
ARG HIP_VISIBLE_DEVICES=0
ARG PYTORCH_ROCM_VERSION=5.7
# PyTorch 2.2.1 does not support torch_compile with Python 3.12
ARG PYTHON=python3.11
ENV AMDGPU_ARCH="${AMDGPU_ARCH}" \
    HIP_VISIBLE_DEVICES="${HIP_VISIBLE_DEVICES}" \
    HSA_OVERRIDE_GFX_VERSION="${HSA_OVERRIDE_GFX_VERSION}" \
    PYTORCH_ROCM_VERSION="${PYTORCH_ROCM_VERSION}" \
    PYTHON="${PYTHON}" \
    CLANG_VER="17"
# Runtime dependencies (python-devel is needed for torch_compile).
# remove-gfx.sh is a project helper that runs after package installation.
# NOTE(review): presumably it strips GPU code for other gfx targets - confirm.
COPY containers/rocm/remove-gfx.sh /tmp/
# sharing=locked (as in the build-env stage) serializes access to the dnf
# cache so that concurrent builds cannot modify it at the same time.
RUN --mount=type=cache,sharing=locked,target=/var/cache/dnf,z \
    dnf install -y --nodocs --setopt=install_weak_deps=False --setopt=keepcache=True \
        gh \
        git \
        hipblas \
        hiprand \
        hipsparse \
        lld-libs \
        make \
        nvtop \
        ${PYTHON} \
        ${PYTHON}-devel \
        radeontop \
        rocm-runtime && \
    /tmp/remove-gfx.sh
# Virtual env; umask 0000 allows the end-user to write to the venv later.
# NOTE(review): ${PS1} is substituted when the image is built, not by the
# user's shell - confirm that this is the intended prompt.
ENV PIP_DISABLE_PIP_VERSION_CHECK=1 \
    VIRTUAL_ENV="/opt/app-root" \
    PS1="(rocm-venv) ${PS1}"
RUN umask 0000 && \
    ${PYTHON} -m venv --upgrade-deps ${VIRTUAL_ENV}
# Put the venv first on PATH and keep all caches inside the venv tree.
ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" \
    XDG_CACHE_HOME="${VIRTUAL_ENV}/.cache" \
    PIP_CACHE_DIR="${VIRTUAL_ENV}/.cache/pip"
# World-writable cache directories for pip and huggingface.
RUN mkdir -p -m777 ${PIP_CACHE_DIR} ${XDG_CACHE_HOME}/huggingface
# SELinux workaround for https://github.com/python/cpython/issues/83074
COPY --chown=0:0 containers/sitecustomize.py ${VIRTUAL_ENV}/lib/${PYTHON}/site-packages/
# Additional helpers to debug torch and llama. The destination ends with "/"
# because a wildcard source may expand to several files.
COPY --chown=0:0 containers/bin/debug-* ${VIRTUAL_ENV}/bin/


# Stage "buildenv": runtime plus compilers and ROCm/HIP development
# packages. Only used to build wheels; the final image is not based on it.
FROM runtime AS buildenv
RUN --mount=type=cache,sharing=locked,target=/var/cache/dnf,z \
    dnf install -y --nodocs --setopt=keepcache=True \
        clang-tools-extra \
        clang${CLANG_VER} \
        cmake \
        compiler-rt${CLANG_VER} \
        gcc \
        hip-devel \
        hipblas-devel \
        hipcub-devel \
        hiprand-devel \
        hipsparse-devel \
        lld \
        llvm \
        llvm${CLANG_VER} \
        ninja-build \
        rocblas-devel \
        rocprim-devel \
        rocthrust-devel


# Stage "pytorch": install the ROCm build of PyTorch into the virtual env.
FROM buildenv AS pytorch
# Intended to keep pip downloads in the cache mount (PIP_NO_CACHE_DIR is set
# to an empty value) and to skip byte compilation of Python files.
ENV PIP_NO_CACHE_DIR= \
    PIP_NO_COMPILE=1
COPY --chown=0:0 requirements.txt ${VIRTUAL_ENV}/
# pip constraint files do not support optional dependencies ("extras"), so
# derive constraints.txt by deleting every "[...]" part from requirements.txt.
RUN sed 's/\[.*\]//' ${VIRTUAL_ENV}/requirements.txt \
        > ${VIRTUAL_ENV}/constraints.txt
# torch comes from the PyTorch wheel index that matches PYTORCH_ROCM_VERSION;
# remove-gfx.sh is run again once the wheel is installed.
RUN --mount=type=cache,sharing=locked,id=pipcache,target=${PIP_CACHE_DIR},mode=777,z \
    umask 0000 && \
    ${VIRTUAL_ENV}/bin/pip install \
        torch \
        --index-url https://download.pytorch.org/whl/rocm${PYTORCH_ROCM_VERSION} && \
    /tmp/remove-gfx.sh

# Stage "llama": build llama-cpp-python from source with hipBLAS enabled for
# the selected AMD GPU target, using the versioned clang as the compiler.
FROM pytorch AS llama
# A cached llama_cpp_python wheel is removed first to force a rebuild.
# FORCE_CMAKE=1 and CMAKE_ARGS are passed to the package build through the
# environment of the pip process.
RUN --mount=type=cache,sharing=locked,id=pipcache,target=${PIP_CACHE_DIR},mode=777,z \
    ${VIRTUAL_ENV}/bin/pip cache remove llama_cpp_python && \
    umask 0000 && \
    CMAKE_ARGS="-DGGML_HIPBLAS=on -DCMAKE_C_COMPILER=/usr/bin/clang-${CLANG_VER} -DCMAKE_CXX_COMPILER=/usr/bin/clang++-${CLANG_VER} -DAMDGPU_TARGETS=${AMDGPU_ARCH}" \
    FORCE_CMAKE=1 \
    ${VIRTUAL_ENV}/bin/pip install \
        -c ${VIRTUAL_ENV}/constraints.txt \
        llama-cpp-python


# Stage "bitsandbytes": build the ROCm fork of bitsandbytes from source.
# No stage in this file is currently based on this one.
FROM llama AS bitsandbytes
# Shallow clones of the "rocm" ref of the bitsandbytes fork and of the
# "rocm-6.0.2" ref of hipBLASLt (only its headers are used below).
RUN git clone --depth 1 -b rocm https://github.com/arlo-phoenix/bitsandbytes-rocm-5.6.git /tmp/bitsandbytes
RUN git clone --depth 1 -b rocm-6.0.2 https://github.com/ROCm/hipBLASLt /tmp/hipblaslt
# Provide hipBLASLt headers inside the bitsandbytes tree: write a stub
# hipblaslt-export.h that defines HIPBLASLT_EXPORT as empty, create an empty
# hipblaslt-version.h, then copy the headers from the hipBLASLt checkout.
RUN mkdir -p /tmp/bitsandbytes/include/hipblaslt && \
    echo -e '#pragma once\n#ifndef HIPBLASLT_EXPORT\n#define HIPBLASLT_EXPORT\n#endif' > /tmp/bitsandbytes/include/hipblaslt/hipblaslt-export.h && \
    touch /tmp/bitsandbytes/include/hipblaslt/hipblaslt-version.h && \
    cp /tmp/hipblaslt/library/include/* /tmp/bitsandbytes/include/hipblaslt/
# Build the "hip" make target for the selected GPU architecture, then install
# the package into the virtual env (umask 0000 keeps it user-writable).
RUN cd /tmp/bitsandbytes && \
    make hip ROCM_TARGET="${AMDGPU_ARCH}" && \
    umask 0000 && \
    ${VIRTUAL_ENV}/bin/pip install -c ${VIRTUAL_ENV}/constraints.txt .


# Stage "pip-install": install the remaining Python requirements.
# bitsandbytes does not compile at the moment, so this stage is based on
# "llama" instead:
# FROM bitsandbytes AS pip-install
FROM llama AS pip-install
# requirements.txt is installed last: pip does not override packages that
# are already installed unless there is a version conflict.
RUN --mount=type=cache,sharing=locked,id=pipcache,target=${PIP_CACHE_DIR},mode=777,z \
    umask 0000 && \
    ${VIRTUAL_ENV}/bin/pip install \
        wheel \
        setuptools-scm && \
    ${VIRTUAL_ENV}/bin/pip install \
        -r ${VIRTUAL_ENV}/requirements.txt


# Stage "instructlab": install InstructLab itself, last, on top of the
# requirements installed by the parent stage.
FROM pip-install AS instructlab
COPY --chown=0:0 . /tmp/instructlab
# --no-deps: only the project is installed here. The requirement is quoted
# so the shell cannot interpret "[rocm]" as a glob character class.
RUN --mount=type=cache,sharing=locked,id=pipcache,target=${PIP_CACHE_DIR},mode=777,z \
    umask 0000 && \
    ${VIRTUAL_ENV}/bin/pip install --no-deps "/tmp/instructlab[rocm]"
# Delete byte-code caches from the venv. -prune keeps find from descending
# into a directory that is being removed, and -exec ... {} + avoids a pipe,
# so a failing find is not masked and odd path names are handled safely.
RUN find ${VIRTUAL_ENV} -name __pycache__ -prune -exec rm -rf {} +


# Stage "final": start again from the base runtime stage and copy only the
# populated virtual env (site-packages and bin) out of the build chain, so
# compilers and -devel packages do not end up in the final image.
FROM runtime AS final
COPY --from=instructlab ${VIRTUAL_ENV}/lib/${PYTHON}/site-packages ${VIRTUAL_ENV}/lib/${PYTHON}/site-packages
COPY --from=instructlab ${VIRTUAL_ENV}/bin ${VIRTUAL_ENV}/bin
WORKDIR "${VIRTUAL_ENV}"
# Image metadata. The name is derived from AMDGPU_ARCH (e.g. "gfx1100"),
# which is the variable this file defines for the GPU target.
LABEL com.github.containers.toolbox="true" \
      com.github.instructlab.instructlab.target="rocm-${AMDGPU_ARCH}" \
      name="instructlab-rocm-base-${AMDGPU_ARCH}" \
      usage="This image is meant to be used with the toolbox(1) command" \
      summary="PyTorch, llama.cpp, and InstructLab dependencies for AMD ROCm GPU ${AMDGPU_ARCH}" \
      maintainer="Christian Heimes <cheimes@redhat.com>" \
      com.redhat.component="instructlab"
