# Skip to content
#
# multiprocess-tests-schedule-run #239
#
# multiprocess-tests-schedule-run
#
# multiprocess-tests-schedule-run #239

name: multiprocess-tests-schedule-run

on:
  # Continuous testing: the only trigger is the timer below.
  schedule:
    # Run every 4 hours, at minute 0.
    - cron: "0 */4 * * *"

# Listing any scope here sets every scope that is NOT listed to `none`, so
# each API the job calls with GITHUB_TOKEN needs an explicit entry.
permissions:
  contents: read
  actions: write  # to cancel previous workflows
  statuses: write  # the last job step POSTs a commit status to /statuses/<sha>

# A new run in the same group cancels the one that is still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  # Installs the package under `checkpoint/`, runs every benchmark config
  # listed in multiprocess_benchmark_configs.txt, runs the pytest files listed
  # in multiprocess_tests.txt, and finally posts the job result as a commit
  # status.
  multiprocess-checkpoint-benchmarks:
    name: "multiprocess-checkpoint-benchmarks (Python ${{ matrix.python-version }}, jax=${{ matrix.jax-version }})"
    runs-on: linux-g2-16-l4-1gpu-x4
    # runs-on: linux-x86-ct5lp-4tpu-x4
    container: us-docker.pkg.dev/ml-oss-artifacts-published/ml-public-container/ml-build:infrastructure-public-image-2d2a7b1e6e2e
    defaults:
      run:
        # Every `run:` step below starts in this directory of the checkout.
        working-directory: checkpoint
    strategy:
      matrix:
        # Quoted so the versions stay strings (3.12 / 0.6.0 must not be
        # re-typed as numbers).
        python-version: ["3.12"]
        jax-version: ["0.6.0"]
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        env:
          # Passed through the environment instead of being interpolated into
          # the script text.
          JAX_VERSION: ${{ matrix.jax-version }}
        run: |
          pip install -e .
          # Extras are quoted so the shell cannot treat [...] as a glob pattern.
          pip install -e ".[testing,gcs]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
          pip uninstall -y orbax
          # "newest" and "nightly" are symbolic; any other value is used as a
          # minimum version for both jax and jaxlib.
          if [[ "$JAX_VERSION" == "newest" ]]; then
            pip install -U "jax[k8s,cuda12]" jaxlib
          elif [[ "$JAX_VERSION" == "nightly" ]]; then
            pip install -U --pre "jax[k8s,cuda12]" jaxlib --extra-index-url https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/
          else
            pip install "jax[k8s,cuda12]>=${JAX_VERSION}" "jaxlib>=${JAX_VERSION}"
          fi
          pip install gcsfs
          pip install portpicker

      - name: Run correctness integration tests
        env:
          GCS_BUCKET_PATH: gs://orbax-benchmarks/benchmark-results/${{ github.run_id }}
          # Environment values are strings; quote them so YAML does not type
          # them as booleans first.
          TF_FORCE_GPU_ALLOW_GROWTH: "true"
          XLA_PYTHON_CLIENT_PREALLOCATE: "false"
        run: |
          cd orbax/checkpoint/_src/testing/benchmarks
          failed_benchmarks=""
          benchmark_configs_file="multiprocess_benchmark_configs.txt"
          echo "Running benchmarks specified in $benchmark_configs_file"
          # `|| [ -n "$entry" ]` still processes a last line that has no
          # trailing newline; empty lines are skipped.
          while IFS= read -r entry || [ -n "$entry" ]; do
            if [ -n "$entry" ]; then
              echo "Running benchmark for $entry"
              # Keep going after a failure so every config is attempted; the
              # failures are collected and reported together below.
              if ! python run_benchmarks.py --config_file="$entry" --output_directory="$GCS_BUCKET_PATH"; then
                echo "Benchmark $entry failed"
                failed_benchmarks="$failed_benchmarks $entry"
              fi
            fi
          done < "$benchmark_configs_file"
          cd ../../../../..
          if [ -n "$failed_benchmarks" ]; then
            echo "The following benchmarks failed:$failed_benchmarks"
            exit 1
          fi
          # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py
          # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH

      - name: Run multiprocess tests
        env:
          TEST_TMPDIR: gs://orbax-benchmarks/unit-tests/${{ github.run_id }}
        run: |
          # pytest.main() returns its exit code rather than exiting, so it is
          # handed to sys.exit(); otherwise failing tests would leave this
          # step green.
          python - <<'PY'
          import sys

          import jax
          import pytest

          jax.distributed.initialize()
          print(jax.devices())
          # One test file path per non-blank line.
          with open('orbax/checkpoint/_src/testing/multiprocess_tests.txt') as listing:
              test_files = [line.strip() for line in listing if line.strip()]
          # '-c /dev/null' points pytest at an empty config file.
          sys.exit(pytest.main(['-c', '/dev/null'] + test_files))
          PY
          # python -m pytest orbax/checkpoint/_src/handlers/array_checkpoint_handler_test.py;"
          # cd orbax/checkpoint/_src/testing/benchmarks && python run_benchmarks.py --config_file=configs/pytree_checkpoint_benchmark.yaml --output_directory=$GCS_BUCKET_PATH
          # python -m pytest orbax/checkpoint/checkpoint_manager_test.py

      # The below step just reports the success or failure of tests as a "commit status".
      # This is needed for copybara integration.
      - name: Report success or failure as github status
        if: always()
        shell: bash
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
          JOB_STATUS: ${{ job.status }}
        run: |
          # job.status can be success, failure or cancelled, but the commit
          # status API only accepts error, failure, pending or success, so
          # anything else is reported as "error".
          state=$(echo "$JOB_STATUS" | tr '[:upper:]' '[:lower:]')
          case "$state" in
            success|failure) ;;
            *) state="error" ;;
          esac
          # curl runs without --fail, so an API error is printed but does not
          # fail this step (reporting stays best-effort).
          curl -sS --request POST \
            --url "https://api.github.com/repos/${{ github.repository }}/statuses/${{ github.sha }}" \
            --header "authorization: Bearer ${GITHUB_TOKEN}" \
            --header 'content-type: application/json' \
            --data '{
              "state": "'"$state"'",
              "target_url": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}",
              "description": "'"$JOB_STATUS"'",
              "context": "github-actions/build"
            }'